Preface


In 2012, I published a 1200-page book called Machine Learning: A Probabilistic Perspective, which
provided fairly comprehensive coverage of the field of machine learning (ML) at that time, under
the unifying lens of probabilistic modeling. The book was well received, and won the DeGroot Prize
in 2013.

The year 2012 is also generally considered the start of the “deep learning revolution”. The term 
“deep learning” refers to a branch of ML that is based on deep neural networks (DNNs), which are nonlinear
functions with many layers of processing (hence the term “deep”). Although this basic technology had 
been around for many years, it was in 2012 when [KSH12] used DNNs to win the ImageNet image 
classification challenge by such a large margin that it caught the attention of the wider community. 
Related advances on other hard problems, such as speech recognition, appeared around the same time 
(see e.g., [Cir+10; Cir+11; Hin+12]). These breakthroughs were enabled by advances in hardware 
technology (in particular, the repurposing of fast graphics processing units (GPUs) from video games 
to ML), data collection technology (in particular, the use of crowd sourcing tools, such as Amazon’s 
Mechanical Turk platform, to collect large labeled datasets, such as ImageNet), as well as various 
new algorithmic ideas, some of which we cover in this book. 

Since 2012, the field of deep learning has exploded, with new advances coming at an increasing 
pace. Interest in the field has also grown rapidly, fueled by the commercial success of the technology, 
and the breadth of applications to which it can be applied. Therefore, in 2018, I decided to write a 
second edition of my book, to attempt to summarize some of this progress. 

By March 2020, my draft of the second edition had swollen to about 1600 pages, and I still had 
many topics left to cover. As a result, MIT Press told me I would need to split the book into two 
volumes. Then the COVID-19 pandemic struck. I decided to pivot away from book writing, and to 
help develop the risk score algorithm for Google’s exposure notification app [MKS21] as well as to 
assist with various forecasting projects [Wah+22]. However, by the Fall of 2020, I decided to return
to working on the book. 

To make up for lost time, I asked several colleagues to help me finish by writing various sections (see 
acknowledgements below). The result of all this is two new books, “Probabilistic Machine Learning: 
An Introduction”, which you are currently reading, and “Probabilistic Machine Learning: Advanced 
Topics”, which is the sequel to this book [Mur23]. Together these two books attempt to present a 
fairly broad coverage of the field of ML c. 2021, using the same unifying lens of probabilistic modeling 
and Bayesian decision theory that I used in the 2012 book. 

Nearly all of the content from the 2012 book has been retained, but it is now split fairly evenly
between the two new books. In addition, each new book has lots of fresh material, covering topics from
deep learning, as well as advances in other parts of the field, such as generative models, variational 
inference and reinforcement learning. 

To make this introductory book more self-contained and useful for students, I have added some 
background material, on topics such as optimization and linear algebra, that was omitted from the 
2012 book due to lack of space. Advanced material, that can be skipped during an introductory 
level course, is denoted by an asterisk * in the section or chapter title. Exercises can be found 
at the end of some chapters. Solutions to exercises marked with an asterisk * are available to 
qualified instructors by contacting MIT Press; solutions to all other exercises can be found online at 
probml.github.io/book1, along with additional teaching material (e.g., figures and slides). 

Another major change is that all of the software now uses Python instead of Matlab. (In the 
future, we may create a Julia version of the code.) The new code leverages standard Python libraries, 
such as NumPy, Scikit-learn, JAX, PyTorch, TensorFlow, PyMC, etc. 

If a figure caption says “Generated by iris_plot.ipynb”, then you can find the corresponding 
Jupyter notebook at probml.github.io/notebooks#iris_plot.ipynb. Clicking on the figure link in the
pdf version of the book will take you to this list of notebooks. Clicking on the notebook link will 
open it inside Google Colab, which will let you easily reproduce the figure for yourself, and modify 
the underlying source code to gain a deeper understanding of the methods. (Colab gives you access 
to a free GPU, which is useful for some of the more computationally heavy demos.)

1 Introduction


1.1 What is machine learning? 


A popular definition of machine learning or ML, due to Tom Mitchell [Mit97], is as follows: 


A computer program is said to learn from experience E with respect to some class of tasks T, 
and performance measure P, if its performance at tasks in T, as measured by P, improves with 
experience E. 


Thus there are many different kinds of machine learning, depending on the nature of the tasks T we 
wish the system to learn, the nature of the performance measure P we use to evaluate the system, 
and the nature of the training signal or experience E we give it. 

In this book, we will cover the most common types of ML, but from a probabilistic perspective. 
Roughly speaking, this means that we treat all unknown quantities (e.g., predictions about the 
future value of some quantity of interest, such as tomorrow’s temperature, or the parameters of some 
model) as random variables, that are endowed with probability distributions which describe a 
weighted set of possible values the variable may have. (See Chapter 2 for a quick refresher on the 
basics of probability, if necessary.) 

There are two main reasons we adopt a probabilistic approach. First, it is the optimal approach to 
decision making under uncertainty, as we explain in Section 5.1. Second, probabilistic modeling 
is the language used by most other areas of science and engineering, and thus provides a unifying 
framework between these fields. As Shakir Mohamed, a researcher at DeepMind, put it:¹


Almost all of machine learning can be viewed in probabilistic terms, making probabilistic 
thinking fundamental. It is, of course, not the only view. But it is through this view that we 
can connect what we do in machine learning to every other computational science, whether that 
be in stochastic optimisation, control theory, operations research, econometrics, information 
theory, statistical physics or bio-statistics. For this reason alone, mastery of probabilistic 
thinking is essential. 


1.2 Supervised learning 


The most common form of ML is supervised learning. In this problem, the task T is to learn
a mapping $f$ from inputs $\boldsymbol{x} \in \mathcal{X}$ to outputs $\boldsymbol{y} \in \mathcal{Y}$. The inputs $\boldsymbol{x}$ are also called the features,
covariates, or predictors; this is often a fixed-dimensional vector of numbers, such as the height
and weight of a person, or the pixels in an image. In this case, $\mathcal{X} = \mathbb{R}^D$, where $D$ is the dimensionality
of the vector (i.e., the number of input features). The output $\boldsymbol{y}$ is also known as the label, target, or
response.² The experience E is given in the form of a set of $N$ input-output pairs $\mathcal{D} = \{(\boldsymbol{x}_n, \boldsymbol{y}_n)\}_{n=1}^N$,
known as the training set. ($N$ is called the sample size.) The performance measure P depends
on the type of output we are predicting, as we discuss below.


1. Source: Slide 2 of https://bit.ly/3pyHyPn


Figure 1.1: Three types of Iris flowers: Setosa, Versicolor and Virginica. Used with kind permission of Dennis
Kramb and SIGNA.


index   sl    sw    pl    pw    label
0       5.1   3.5   1.4   0.2   Setosa
1       4.9   3.0   1.4   0.2   Setosa
...
50      7.0   3.2   4.7   1.4   Versicolor
...
149     5.9   3.0   5.1   1.8   Virginica


Table 1.1: A subset of the Iris design matrix. The features are: sepal length (sl), sepal width (sw), petal
length (pl), petal width (pw). There are 50 examples of each class.


1.2.1 Classification 


In classification problems, the output space is a set of $C$ unordered and mutually exclusive labels
known as classes, $\mathcal{Y} = \{1, 2, \ldots, C\}$. The problem of predicting the class label given an input is
also called pattern recognition. (If there are just two classes, often denoted by $y \in \{0, 1\}$ or
$y \in \{-1, +1\}$, it is called binary classification.)


1.2.1.1 Example: classifying Iris flowers 


As an example, consider the problem of classifying Iris flowers into their 3 subspecies, Setosa, 
Versicolor and Virginica. Figure 1.1 shows one example of each of these classes. 


2. Sometimes (e.g., in the statsmodels Python package) $\boldsymbol{x}$ are called the exogenous variables and $\boldsymbol{y}$ are called the
endogenous variables.


[Figure 1.2 panels: an image and its pixel array ("What the computer sees"), mapped by image classification
to class probabilities: 82% cat, 15% dog, 2% hat, 1% mug.]

Figure 1.2: Illustration of the image classification problem. From https://cs231n.github.io/. Used with
kind permission of Andrej Karpathy.


In image classification, the input space $\mathcal{X}$ is the set of images, which is a very high-dimensional
space: for a color image with $C = 3$ channels (e.g., RGB) and $D_1 \times D_2$ pixels, we have $\mathcal{X} = \mathbb{R}^D$,
where $D = C \times D_1 \times D_2$. (In practice we represent each pixel intensity with an integer, typically from
the range $\{0, 1, \ldots, 255\}$, but we assume real valued inputs for notational simplicity.) Learning a
mapping $f: \mathcal{X} \rightarrow \mathcal{Y}$ from images to labels is quite challenging, as illustrated in Figure 1.2. However,
it can be tackled using certain kinds of functions, such as a convolutional neural network or
CNN, which we discuss in Section 14.1.

Fortunately for us, some botanists have already identified 4 simple, but highly informative, numeric 
features — sepal length, sepal width, petal length, petal width — which can be used to distinguish 
the three kinds of Iris flowers. In this section, we will use this much lower-dimensional input space,
$\mathcal{X} = \mathbb{R}^4$, for simplicity. The Iris dataset is a collection of 150 labeled examples of Iris flowers, 50 of
each type, described by these 4 features. It is widely used as an example, because it is small and 
simple to understand. (We will discuss larger and more complex datasets later in the book.) 

When we have small datasets of features, it is common to store them in an $N \times D$ matrix, in which
each row represents an example, and each column represents a feature. This is known as a design
matrix; see Table 1.1 for an example.³

The Iris dataset is an example of tabular data. When the inputs are of variable size (e.g.,
sequences of words, or social networks), rather than fixed-length vectors, the data is usually stored
in some other format rather than in a design matrix. However, such data is often converted to a
fixed-sized feature representation (a process known as featurization), thus implicitly creating a
design matrix for further processing. We give an example of this in Section 1.5.4.1, where we discuss
the “bag of words” representation for sequence data.


3. This particular design matrix has $N = 150$ rows and $D = 4$ columns, and hence has a tall and skinny shape, since
$N \gg D$. By contrast, some datasets (e.g., genomics) have more features than examples, $D \gg N$; their design matrices
are short and fat. The term “big data” usually means that $N$ is large, whereas the term “wide data” means that
$D$ is large (relative to $N$).


Figure 1.3: Visualization of the Iris data as a pairwise scatter plot. On the diagonal we plot the marginal
distribution of each feature for each class. The off-diagonals contain scatterplots of all possible pairs of
features. Generated by iris_plot.ipynb.


1.2.1.2 Exploratory data analysis 


Before tackling a problem with ML, it is usually a good idea to perform exploratory data analysis, 
to see if there are any obvious patterns (which might give hints on what method to choose), or any 
obvious problems with the data (e.g., label noise or outliers). 

For tabular data with a small number of features, it is common to make a pair plot, in which
panel $(i, j)$ shows a scatter plot of variables $i$ and $j$, and the diagonal entries $(i, i)$ show the marginal
density of variable $i$; all plots are optionally color coded by class label (see Figure 1.3 for an
example).

For higher-dimensional data, it is common to first perform dimensionality reduction, and then
to visualize the data in 2d or 3d. We discuss methods for dimensionality reduction in Chapter 20.


[Figure 1.4 panels: (a) decision tree. Root node: petal length (cm) <= 2.45, samples = 150, value = [50, 50, 50],
class = setosa. Child node: petal width (cm) <= 1.75, samples = 100, value = [0, 50, 50], class = versicolor.
(b) decision surface, with axes petal length (cm) and petal width (cm) and points labeled setosa, versicolor,
virginica.]

Figure 1.4: Example of a decision tree of depth 2 applied to the Iris data, using just the petal length and petal
width features. Leaf nodes are color coded according to the predicted class. The number of training samples
that pass from the root to a node is shown inside each box; we show how many values of each class fall into
this node. This vector of counts can be normalized to get a distribution over class labels for each node. We
can then pick the majority class. Adapted from Figures 6.1 and 6.2 of [Gér19]. Generated by iris_dtree.ipynb.


1.2.1.3 Learning a classifier 


From Figure 1.3, we can see that the Setosa class is easy to distinguish from the other two classes. 
For example, suppose we create the following decision rule: 


$$ f(\boldsymbol{x}; \boldsymbol{\theta}) = \begin{cases} \text{Setosa} & \text{if petal length} < 2.45 \\ \text{Versicolor or Virginica} & \text{otherwise} \end{cases} \tag{1.1} $$

This is a very simple example of a classifier, in which we have partitioned the input space into two
regions, defined by the one-dimensional (1d) decision boundary at $x_{\text{petal length}} = 2.45$. Points
lying to the left of this boundary are classified as Setosa; points to the right are either Versicolor or
Virginica.

                      Estimate
                      Setosa   Versicolor   Virginica
Truth   Setosa        0        1            1
        Versicolor    1        0            1
        Virginica     10       10           0


Table 1.2: Hypothetical asymmetric loss matrix for Iris classification.


We see that this rule perfectly classifies the Setosa examples, but not the Virginica and Versicolor
ones. To improve performance, we can recursively partition the space, by splitting regions in which
the classifier makes errors. For example, we can add another decision rule, to be applied to inputs
that fail the first test, to check if the petal width is below 1.75cm (in which case we predict Versicolor)
or above (in which case we predict Virginica). We can arrange these nested rules into a tree structure,
called a decision tree, as shown in Figure 1.4a. This induces the 2d decision surface shown in
Figure 1.4b.

We can represent the tree by storing, for each internal node, the feature index that is used, as well
as the corresponding threshold value. We denote all these parameters by $\boldsymbol{\theta}$. We discuss how to
learn these parameters in Section 18.1.
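
The nested rules described above can be written out directly as code. The following is a minimal sketch, not the book's notebook: the function name `predict_iris` and the four example measurements are illustrative assumptions, and the thresholds 2.45 and 1.75 are hard-coded from the text rather than learned.

```python
def predict_iris(petal_length, petal_width,
                 length_threshold=2.45, width_threshold=1.75):
    """Depth-2 decision tree written as two nested threshold tests."""
    if petal_length <= length_threshold:
        return "Setosa"
    # Inputs that fail the first test are split on petal width.
    if petal_width <= width_threshold:
        return "Versicolor"
    return "Virginica"

# Illustrative (petal length, petal width) measurements in cm, with labels.
examples = [
    (1.4, 0.2, "Setosa"),
    (1.4, 0.2, "Setosa"),
    (4.7, 1.4, "Versicolor"),
    (5.1, 1.8, "Virginica"),
]
predictions = [predict_iris(pl, pw) for pl, pw, _ in examples]
```

Here the parameters of the tree are just the pair (feature, threshold) at each internal node, which is why the two thresholds are exposed as arguments.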


1.2.1.4 Empirical risk minimization 


The goal of supervised learning is to automatically come up with classification models such as the 
one shown in Figure 1.4a, so as to reliably predict the labels for any given input. A common way to 
measure performance on this task is in terms of the misclassification rate on the training set: 


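$$ \mathcal{L}(\boldsymbol{\theta}) \triangleq \frac{1}{N} \sum_{n=1}^{N} \mathbb{I}\left(y_n \neq f(\boldsymbol{x}_n; \boldsymbol{\theta})\right) \tag{1.2} $$

where $\mathbb{I}(e)$ is the binary indicator function, which returns 1 iff (if and only if) the condition $e$ is
true, and returns 0 otherwise, i.e.,

$$ \mathbb{I}(e) = \begin{cases} 1 & \text{if } e \text{ is true} \\ 0 & \text{if } e \text{ is false} \end{cases} \tag{1.3} $$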


This assumes all errors are equal. However, it may be the case that some errors are more costly
than others. For example, suppose we are foraging in the wilderness and we find some Iris flowers.
Furthermore, suppose that Setosa and Versicolor are tasty, but Virginica is poisonous. In this case,
we might use the asymmetric loss function $\ell(y, \hat{y})$ shown in Table 1.2.

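We can then define the empirical risk to be the average loss of the predictor on the training set:

$$ \mathcal{L}(\boldsymbol{\theta}) \triangleq \frac{1}{N} \sum_{n=1}^{N} \ell\left(y_n, f(\boldsymbol{x}_n; \boldsymbol{\theta})\right) \tag{1.4} $$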


We see that the misclassification rate in Equation (1.2) is equal to the empirical risk when we use
zero-one loss for comparing the true label with the prediction:

$$ \ell_{01}(y, \hat{y}) = \mathbb{I}(y \neq \hat{y}) \tag{1.5} $$


See Section 5.1 for more details.

One way to define the problem of model fitting or training is to find a setting of the parameters
that minimizes the empirical risk on the training set:

$$ \hat{\boldsymbol{\theta}} = \operatorname*{argmin}_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}) = \operatorname*{argmin}_{\boldsymbol{\theta}} \frac{1}{N} \sum_{n=1}^{N} \ell\left(y_n, f(\boldsymbol{x}_n; \boldsymbol{\theta})\right) \tag{1.6} $$


This is called empirical risk minimization. 
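
To make the empirical risk concrete, the sketch below evaluates it for a fixed set of predictions under both the zero-one loss and the asymmetric loss of Table 1.2. The labels and predictions are hypothetical values chosen for illustration, and the helper name `empirical_risk` is an assumption, not part of the book's code.

```python
import numpy as np

# Rows index the true class, columns the predicted class,
# in the order Setosa, Versicolor, Virginica (as in Table 1.2).
asymmetric_loss = np.array([[0, 1, 1],
                            [1, 0, 1],
                            [10, 10, 0]])
zero_one_loss = 1 - np.eye(3, dtype=int)

def empirical_risk(y_true, y_pred, loss_matrix):
    """Average loss of the predictions over the dataset."""
    return loss_matrix[y_true, y_pred].mean()

# Hypothetical labels and predictions (0 = Setosa, 1 = Versicolor, 2 = Virginica).
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 1, 2, 1, 2])

misclassification_rate = empirical_risk(y_true, y_pred, zero_one_loss)
asymmetric_risk = empirical_risk(y_true, y_pred, asymmetric_loss)
```

Both quantities count the same two mistakes, but the asymmetric loss weights the Virginica flower mislabeled as Versicolor ten times more heavily, so the two risks differ even though the predictions are identical.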

However, our true goal is to minimize the expected loss on future data that we have not yet 
seen. That is, we want to generalize, rather than just do well on the training set. We discuss this 
important point in Section 1.2.3. 


1.2.1.5 Uncertainty 


[We must avoid] false confidence bred from an ignorance of the probabilistic nature of the 
world, from a desire to see black and white where we should rightly see gray. — Immanuel 
Kant, as paraphrased by Maria Konnikova [Kon20]. 


In many cases, we will not be able to perfectly predict the exact output given the input, due to 
lack of knowledge of the input-output mapping (this is called epistemic uncertainty or model 
uncertainty), and/or due to intrinsic (irreducible) stochasticity in the mapping (this is called 
aleatoric uncertainty or data uncertainty). 

Representing uncertainty in our prediction can be important for various applications. For example, 
let us return to our poisonous flower example, whose loss matrix is shown in Table 1.2. If we predict 
the flower is Virginica with high probability, then we should not eat the flower. Alternatively, we 
may be able to perform an information gathering action, such as performing a diagnostic test, to 
reduce our uncertainty. For more information about how to make optimal decisions in the presence 
of uncertainty, see Section 5.1. 

We can capture our uncertainty using the following conditional probability distribution: 


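$$ p(y = c \mid \boldsymbol{x}; \boldsymbol{\theta}) = f_c(\boldsymbol{x}; \boldsymbol{\theta}) \tag{1.7} $$

where $f: \mathcal{X} \rightarrow [0, 1]^C$ maps inputs to a probability distribution over the $C$ possible output labels.
Since $f_c(\boldsymbol{x}; \boldsymbol{\theta})$ returns the probability of class label $c$, we require $0 \leq f_c \leq 1$ for each $c$, and $\sum_{c=1}^{C} f_c = 1$.
To avoid this restriction, it is common to instead require the model to return unnormalized log-probabilities.
We can then convert these to probabilities using the softmax function, which is
defined as follows:

$$ \text{softmax}(\boldsymbol{a}) \triangleq \left[ \frac{e^{a_1}}{\sum_{c'=1}^{C} e^{a_{c'}}}, \ldots, \frac{e^{a_C}}{\sum_{c'=1}^{C} e^{a_{c'}}} \right] \tag{1.8} $$

This maps $\mathbb{R}^C$ to $[0, 1]^C$, and satisfies the constraints that $0 \leq \text{softmax}(\boldsymbol{a})_c \leq 1$ and $\sum_{c=1}^{C} \text{softmax}(\boldsymbol{a})_c = 1$.
The inputs to the softmax, $\boldsymbol{a} = f(\boldsymbol{x}; \boldsymbol{\theta})$, are called logits. See Section 2.5.2 for details. We thus
define the overall model as follows: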


$$ p(y = c \mid \boldsymbol{x}; \boldsymbol{\theta}) = \text{softmax}_c(f(\boldsymbol{x}; \boldsymbol{\theta})) \tag{1.9} $$

A common special case of this arises when $f$ is an affine function of the form

$$ f(\boldsymbol{x}; \boldsymbol{\theta}) = b + \boldsymbol{w}^\top \boldsymbol{x} = b + w_1 x_1 + w_2 x_2 + \cdots + w_D x_D \tag{1.10} $$

where $\boldsymbol{\theta} = (b, \boldsymbol{w})$ are the parameters of the model. This model is called logistic regression, and
will be discussed in more detail in Chapter 10.
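
The softmax function is short enough to write out in full. The sketch below is one possible NumPy implementation; the logit values are hypothetical, and subtracting the maximum logit before exponentiating is a standard numerical safeguard that is not part of the definition in the text (it leaves the output unchanged because the common factor cancels in the ratio).

```python
import numpy as np

def softmax(a):
    """Map a vector of logits to a vector of class probabilities."""
    a = np.asarray(a, dtype=float)
    # Shifting by the max does not change the result but avoids overflow in exp.
    z = np.exp(a - a.max())
    return z / z.sum()

logits = np.array([2.0, 1.0, 0.1])  # hypothetical values of f(x; theta)
probs = softmax(logits)
```

The output is non-negative and sums to one, so it can be read as the distribution over the C class labels, and the largest logit receives the largest probability.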

In statistics, the $\boldsymbol{w}$ parameters are usually called regression coefficients (and are typically
denoted by $\boldsymbol{\beta}$) and $b$ is called the intercept. In ML, the parameters $\boldsymbol{w}$ are called the weights and $b$
is called the bias. This terminology arises from electrical engineering, where we view the function $f$
as a circuit which takes in $\boldsymbol{x}$ and returns $f(\boldsymbol{x})$. Each input is fed to the circuit on “wires”, which
have weights $\boldsymbol{w}$. The circuit computes the weighted sum of its inputs, and adds a constant bias or
offset term $b$. (This use of the term “bias” should not be confused with the statistical concept of bias
discussed in Section 4.7.6.1.)

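To reduce notational clutter, it is common to absorb the bias term $b$ into the weights $\boldsymbol{w}$ by defining
$\tilde{\boldsymbol{w}} = [b, w_1, \ldots, w_D]$ and defining $\tilde{\boldsymbol{x}} = [1, x_1, \ldots, x_D]$, so that

$$ \tilde{\boldsymbol{w}}^\top \tilde{\boldsymbol{x}} = b + \boldsymbol{w}^\top \boldsymbol{x} \tag{1.11} $$

This converts the affine function into a linear function. We will usually assume that this has been
done, so we can just write the prediction function as follows:

$$ f(\boldsymbol{x}; \boldsymbol{w}) = \boldsymbol{w}^\top \boldsymbol{x} \tag{1.12} $$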
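
The bias-absorption trick can be checked numerically in a few lines. This is only an illustrative sketch: the values of `b`, `w` and `x` are arbitrary assumptions.

```python
import numpy as np

b = 0.5                                # hypothetical bias
w = np.array([1.0, -2.0, 3.0])         # hypothetical weights
x = np.array([0.2, 0.4, 0.6])          # hypothetical input

w_tilde = np.concatenate(([b], w))     # [b, w_1, ..., w_D]
x_tilde = np.concatenate(([1.0], x))   # [1, x_1, ..., x_D]

affine = b + w @ x                     # original affine form
linear = w_tilde @ x_tilde             # equivalent linear form
```

In practice this corresponds to prepending a column of ones to the design matrix, after which the bias is learned like any other weight.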


1.2.1.6 Maximum likelihood estimation 


When fitting probabilistic models, it is common to use the negative log probability as our loss 
function: 


$$ \ell(y, f(\boldsymbol{x}; \boldsymbol{\theta})) = -\log p(y \mid f(\boldsymbol{x}; \boldsymbol{\theta})) \tag{1.13} $$


The reasons for this are explained in Section 5.1.6.1, but the intuition is that a good model (with low 
loss) is one that assigns a high probability to the true output y for each corresponding input x. The 
average negative log probability of the training set is given by 


$$ \text{NLL}(\boldsymbol{\theta}) = -\frac{1}{N} \sum_{n=1}^{N} \log p(y_n \mid f(\boldsymbol{x}_n; \boldsymbol{\theta})) \tag{1.14} $$


This is called the negative log likelihood. If we minimize this, we can compute the maximum 
likelihood estimate or MLE: 


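$$ \hat{\boldsymbol{\theta}}_{\text{mle}} = \operatorname*{argmin}_{\boldsymbol{\theta}} \text{NLL}(\boldsymbol{\theta}) \tag{1.15} $$
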
This is a very common way to fit models to data, as we will see. 
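
As a small illustration of why a low negative log likelihood corresponds to a good model, the sketch below computes the average negative log probability of the true labels for two hypothetical sets of predicted class probabilities: a confident, correct model and a uniform one. The function name and the probability tables are assumptions made for the example.

```python
import numpy as np

def negative_log_likelihood(probs, y):
    """Average negative log probability assigned to the true labels.

    probs[n, c] is the predicted probability p(y = c | x_n; theta);
    y[n] is the integer class label of example n.
    """
    n = np.arange(len(y))
    return -np.mean(np.log(probs[n, y]))

y = np.array([0, 1, 2])
confident = np.array([[0.90, 0.05, 0.05],
                      [0.05, 0.90, 0.05],
                      [0.05, 0.05, 0.90]])
uniform = np.full((3, 3), 1 / 3)

nll_confident = negative_log_likelihood(confident, y)
nll_uniform = negative_log_likelihood(uniform, y)
```

The confident model puts probability 0.9 on each true label, so its NLL is -log(0.9), whereas the uniform model has NLL log(3); minimizing the NLL therefore favors models that concentrate probability on the observed outputs.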


1.2.2 Regression 


Now suppose that we want to predict a real-valued quantity $y \in \mathbb{R}$ instead of a class label
$y \in \{1, \ldots, C\}$; this is known as regression. For example, in the case of Iris flowers, $y$ might be the
degree of toxicity if the flower is eaten, or the average height of the plant.

Regression is very similar to classification. However, since the output is real-valued, we need to
use a different loss function. For regression, the most common choice is to use quadratic loss, or $\ell_2$
loss:

$$ \ell_2(y, \hat{y}) = (y - \hat{y})^2 \tag{1.16} $$

This penalizes large residuals $y - \hat{y}$ more than small ones.⁴ The empirical risk when using quadratic
loss is equal to the mean squared error or MSE:

$$ \text{MSE}(\boldsymbol{\theta}) = \frac{1}{N} \sum_{n=1}^{N} \left(y_n - f(\boldsymbol{x}_n; \boldsymbol{\theta})\right)^2 \tag{1.17} $$


Based on the discussion in Section 1.2.1.5, we should also model the uncertainty in our prediction. 
In regression problems, it is common to assume the output distribution is a Gaussian or normal. 
As we explain in Section 2.6, this distribution is defined by 


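$$ \mathcal{N}(y \mid \mu, \sigma^2) \triangleq \frac{1}{\sqrt{2\pi\sigma^2}} e^{-\frac{1}{2\sigma^2}(y - \mu)^2} \tag{1.18} $$

where $\mu$ is the mean, $\sigma^2$ is the variance, and $\sqrt{2\pi\sigma^2}$ is the normalization constant needed to ensure
the density integrates to 1. In the context of regression, we can make the mean depend on the inputs
by defining $\mu = f(\boldsymbol{x}_n; \boldsymbol{\theta})$. We therefore get the following conditional probability distribution:

$$ p(y \mid \boldsymbol{x}; \boldsymbol{\theta}) = \mathcal{N}(y \mid f(\boldsymbol{x}; \boldsymbol{\theta}), \sigma^2) \tag{1.19} $$

If we assume that the variance $\sigma^2$ is fixed (for simplicity), the corresponding average (per-sample)
negative log likelihood becomes

$$ \text{NLL}(\boldsymbol{\theta}) = -\frac{1}{N} \sum_{n=1}^{N} \log \left[ \left( \frac{1}{2\pi\sigma^2} \right)^{\frac{1}{2}} \exp\left( -\frac{1}{2\sigma^2} \left(y_n - f(\boldsymbol{x}_n; \boldsymbol{\theta})\right)^2 \right) \right] \tag{1.20} $$

$$ = \frac{1}{2\sigma^2} \text{MSE}(\boldsymbol{\theta}) + \text{const} \tag{1.21} $$

We see that the NLL is proportional to the MSE, up to an additive constant that does not depend on
$\boldsymbol{\theta}$. Hence computing the maximum likelihood estimate of the parameters will result in minimizing
the squared error, which seems like a sensible approach to
model fitting.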
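
The relationship between the Gaussian NLL and the MSE can be verified numerically. In the sketch below, the targets and predictions are random placeholder values and the variance is chosen arbitrarily; the additive constant works out to $\frac{1}{2}\log(2\pi\sigma^2)$, which follows from taking the log of the Gaussian density.

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(size=50)        # hypothetical targets
y_hat = rng.normal(size=50)    # hypothetical predictions f(x_n; theta)
sigma2 = 0.7                   # fixed variance, chosen arbitrarily

def gaussian_nll(y, y_hat, sigma2):
    """Average negative log density of y under N(y | y_hat, sigma2)."""
    log_pdf = -0.5 * np.log(2 * np.pi * sigma2) - (y - y_hat) ** 2 / (2 * sigma2)
    return -np.mean(log_pdf)

mse = np.mean((y - y_hat) ** 2)
nll = gaussian_nll(y, y_hat, sigma2)
const = 0.5 * np.log(2 * np.pi * sigma2)
```

Since `nll` equals `mse / (2 * sigma2) + const` and neither the scale factor nor the constant depends on the parameters, any parameter setting that minimizes one also minimizes the other.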

1.2.2.1 Linear regression 


As an example of a regression model, consider the 1d data in Figure 1.5a. We can fit this data using 
a simple linear regression model of the form 


$$ f(x; \boldsymbol{\theta}) = b + wx \tag{1.22} $$


4. If the data has outliers, the quadratic penalty can be too severe. In such cases, it can be better to use $\ell_1$ loss
instead, which is more robust. See Section 11.6 for details.


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




Figure 1.5: (a) Linear regression on some 1d data. (b) The vertical lines denote the residuals between
the observed output value for each input (blue circle) and its predicted value (red cross). The goal of
least squares regression is to pick a line that minimizes the sum of squared residuals. Generated by
linreg_residuals_plot.ipynb.


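where $w$ is the slope, $b$ is the offset, and $\boldsymbol{\theta} = (w, b)$ are all the parameters of the model. By adjusting
$\boldsymbol{\theta}$, we can minimize the sum of squared errors, shown by the vertical lines in Figure 1.5b, until we
find the least squares solution

$$ \hat{\boldsymbol{\theta}} = \operatorname*{argmin}_{\boldsymbol{\theta}} \text{MSE}(\boldsymbol{\theta}) \tag{1.23} $$

See Section 11.2.2.1 for details.

If we have multiple input features, we can write

$$ f(\boldsymbol{x}; \boldsymbol{\theta}) = b + w_1 x_1 + \cdots + w_D x_D = b + \boldsymbol{w}^\top \boldsymbol{x} \tag{1.24} $$

where $\boldsymbol{\theta} = (\boldsymbol{w}, b)$. This is called multiple linear regression.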
For example, consider the task of predicting temperature as a function of 2d location in a room. 
Figure 1.6(a) plots the results of a linear model of the following form: 


$$ f(\boldsymbol{x}; \boldsymbol{\theta}) = b + w_1 x_1 + w_2 x_2 \tag{1.25} $$

We can extend this model to use D > 2 input features (such as time of day), but then it becomes 
harder to visualize. 
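
A least squares fit of the simple linear model can be sketched with NumPy's `np.linalg.lstsq`. The data here is synthetic (a line with slope 1.5 and offset -3.0 plus Gaussian noise), chosen only for illustration; it is not the data shown in Figure 1.5.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 21
x = np.linspace(0.0, 20.0, N)
y = 1.5 * x - 3.0 + rng.normal(scale=0.5, size=N)   # synthetic data

# Design matrix with a leading column of ones, so that theta = (b, w).
X = np.column_stack([np.ones(N), x])
theta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
b_hat, w_hat = theta_hat

residuals = y - X @ theta_hat
mse = np.mean(residuals ** 2)
```

At the least squares solution the residuals are orthogonal to the columns of the design matrix, which is a quick way to check that the fit has converged to the minimizer.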
1.2.2.2 Polynomial regression 


The linear model in Figure 1.5a is obviously not a very good fit to the data. We can improve the
fit by using a polynomial regression model of degree $D$. This has the form $f(x; \boldsymbol{w}) = \boldsymbol{w}^\top \boldsymbol{\phi}(x)$,
where $\boldsymbol{\phi}(x)$ is a feature vector derived from the input, which has the following form:

$$ \boldsymbol{\phi}(x) = [1, x, x^2, \ldots, x^D] \tag{1.26} $$


Figure 1.6: Linear and polynomial regression applied to 2d data. Vertical axis is temperature, horizontal
axes are location within a room. Data was collected by some remote sensing motes at Intel’s lab in Berkeley,
CA (data courtesy of Romain Thibaux). (a) The fitted plane has the form $f(\boldsymbol{x}) = w_0 + w_1 x_1 + w_2 x_2$. (b)
Temperature data is fitted with a quadratic of the form $f(\boldsymbol{x}) = w_0 + w_1 x_1 + w_2 x_2 + w_3 x_1^2 + w_4 x_2^2$.
Generated by linreg_2d_surface_demo.ipynb.


This is a simple example of feature preprocessing, also called feature engineering. 

In Figure 1.7a, we see that using $D = 2$ results in a much better fit. We can keep increasing $D$, and
hence the number of parameters in the model, until $D = N - 1$; in this case, we have one parameter
per data point, so we can perfectly interpolate the data. The resulting model will have 0 MSE, as
shown in Figure 1.7c. However, intuitively the resulting function will not be a good predictor for 
future inputs, since it is too “wiggly”. We discuss this in more detail in Section 1.2.3. 

We can also apply polynomial regression to multi-dimensional inputs. For example, Figure 1.6(b)
plots the predictions for the temperature model after performing a quadratic expansion of the inputs:

$$ f(\boldsymbol{x}; \boldsymbol{w}) = w_0 + w_1 x_1 + w_2 x_2 + w_3 x_1^2 + w_4 x_2^2 \tag{1.27} $$


The quadratic shape is a better fit to the data than the linear model in Figure 1.6(a), since it captures
the fact that the middle of the room is hotter. We can also add cross terms, such as $x_1 x_2$, to capture
interaction effects. See Section 1.5.3.2 for details.

Note that the above models still use a prediction function that is a linear function of the parameters
$\boldsymbol{w}$, even though it is a nonlinear function of the original input $\boldsymbol{x}$. The reason this is important is
that a linear model induces an MSE loss function $\text{MSE}(\boldsymbol{\theta})$ that has a unique global optimum, as we
explain in Section 11.2.2.1.
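
Because the model is linear in the parameters, polynomial regression reduces to ordinary least squares on the expanded features. The sketch below uses synthetic data and hypothetical helper names (`poly_features`, `fit_poly`, `mse`); it fits degrees 1, 2 and N - 1 to N = 6 points to show the training error shrinking to zero at the interpolating degree.

```python
import numpy as np

def poly_features(x, degree):
    """phi(x) = [1, x, x^2, ..., x^degree] for each scalar input."""
    return np.vander(x, degree + 1, increasing=True)

def fit_poly(x, y, degree):
    """Least squares weights for a polynomial of the given degree."""
    w, *_ = np.linalg.lstsq(poly_features(x, degree), y, rcond=None)
    return w

def mse(x, y, w):
    """Mean squared error of the polynomial with weights w."""
    return np.mean((y - poly_features(x, len(w) - 1) @ w) ** 2)

rng = np.random.default_rng(0)
N = 6
x = np.linspace(-1.0, 1.0, N)
y = 1.0 + 2.0 * x - 3.0 * x ** 2 + rng.normal(scale=0.3, size=N)  # synthetic data

train_mse = {d: mse(x, y, fit_poly(x, y, d)) for d in (1, 2, N - 1)}
```

The degree N - 1 fit has one parameter per data point, so its training MSE is zero up to floating-point error, even though the underlying signal is only quadratic.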


1.2.2.3 Deep neural networks 


In Section 1.2.2.2, we manually specified the transformation of the input features, namely polynomial
expansion, $\boldsymbol{\phi}(\boldsymbol{x}) = [1, x_1, x_2, x_1^2, x_2^2, \ldots]$. We can create much more powerful models by learning to
do such nonlinear feature extraction automatically. If we let $\boldsymbol{\phi}(\boldsymbol{x})$ have its own set of parameters,
say $\mathbf{V}$, then the overall model has the form

$$ f(\boldsymbol{x}; \boldsymbol{w}, \mathbf{V}) = \boldsymbol{w}^\top \boldsymbol{\phi}(\boldsymbol{x}; \mathbf{V}) \tag{1.28} $$


Figure 1.7: (a-c) Polynomials of degrees 2, 14 and 20 fit to 21 datapoints (the same data as in Figure 1.5).
(d) MSE vs degree. Generated by linreg_poly_vs_degree.ipynb.


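We can recursively decompose the feature extractor $\boldsymbol{\phi}(\boldsymbol{x}; \mathbf{V})$ into a composition of simpler functions.
The resulting model then becomes a stack of $L$ nested functions:

$$ f(\boldsymbol{x}; \boldsymbol{\theta}) = f_L(f_{L-1}(\cdots(f_1(\boldsymbol{x}))\cdots)) \tag{1.29} $$

where $f_\ell(\boldsymbol{x}) = f(\boldsymbol{x}; \boldsymbol{\theta}_\ell)$ is the function at layer $\ell$. The final layer is linear and has the form
$f_L(\boldsymbol{x}) = \boldsymbol{w}_L^\top \boldsymbol{x}$, so $f(\boldsymbol{x}; \boldsymbol{\theta}) = \boldsymbol{w}_L^\top f_{1:L-1}(\boldsymbol{x})$, where $f_{1:L-1}(\boldsymbol{x}) = f_{L-1}(\cdots(f_1(\boldsymbol{x}))\cdots)$ is the learned
feature extractor. This is the key idea behind deep neural networks or DNNs, which include
common variants such as convolutional neural networks (CNNs) for images, and recurrent
neural networks (RNNs) for sequences. See Part III for details.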
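
The nested structure can be sketched as a loop over layers followed by a final linear map. The text does not specify the form of the hidden layers, so this sketch assumes each one is an affine map followed by a ReLU nonlinearity; the layer sizes and random (untrained) parameters are likewise hypothetical.

```python
import numpy as np

def layer(x, W, b):
    """One hidden layer (assumed form): affine map followed by a ReLU."""
    return np.maximum(0.0, W @ x + b)

rng = np.random.default_rng(0)
D, H1, H2 = 4, 8, 5                          # hypothetical layer sizes
params = [(rng.normal(size=(H1, D)), np.zeros(H1)),
          (rng.normal(size=(H2, H1)), np.zeros(H2))]
w_L = rng.normal(size=H2)                    # weights of the final linear layer

def feature_extractor(x):
    """f_{1:L-1}(x): apply the hidden layers in order."""
    for W, b in params:
        x = layer(x, W, b)
    return x

def f(x):
    """f(x; theta) = w_L^T f_{1:L-1}(x)."""
    return w_L @ feature_extractor(x)

x = rng.normal(size=D)
output = f(x)
```

Here the hidden layers play the role of the learned feature extractor, and the last line of `f` is the same linear prediction function used earlier, applied to the extracted features instead of the raw input.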


1.2.3 Overfitting and generalization 


We can rewrite the empirical risk in Equation (1.4) in the following equivalent way: 


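$$ \mathcal{L}(\boldsymbol{\theta}; \mathcal{D}_{\text{train}}) = \frac{1}{|\mathcal{D}_{\text{train}}|} \sum_{(\boldsymbol{x}, y) \in \mathcal{D}_{\text{train}}} \ell(y, f(\boldsymbol{x}; \boldsymbol{\theta})) \tag{1.30} $$

where $|\mathcal{D}_{\text{train}}|$ is the size of the training set $\mathcal{D}_{\text{train}}$. This formulation is useful because it makes explicit
which dataset the loss is being evaluated on.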

With a suitably flexible model, we can drive the training loss to zero (assuming no label noise), by 
simply memorizing the correct output for each input. For example, Figure 1.7(c) perfectly interpolates 
the training data (modulo the last point on the right). But what we care about is prediction accuracy 
on new data, which may not be part of the training set. A model that perfectly fits the training 
data, but which is too complex, is said to suffer from overfitting. 

To detect if a model is overfitting, let us assume (for now) that we have access to the true (but 
unknown) distribution p*(x,y) used to generate the training set. Then, instead of computing the 
empirical risk we compute the theoretical expected loss or population risk 


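$$ \mathcal{L}(\boldsymbol{\theta}; p^*) \triangleq \mathbb{E}_{p^*(\boldsymbol{x}, y)}\left[ \ell(y, f(\boldsymbol{x}; \boldsymbol{\theta})) \right] \tag{1.31} $$

The difference $\mathcal{L}(\boldsymbol{\theta}; p^*) - \mathcal{L}(\boldsymbol{\theta}; \mathcal{D}_{\text{train}})$ is called the generalization gap. If a model has a large
generalization gap (i.e., low empirical risk but high population risk), it is a sign that it is overfitting.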

In practice we don’t know p*. However, we can partition the data we do have into two subsets, 
known as the training set and the test set. Then we can approximate the population risk using the 
test risk: 


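$$ \mathcal{L}(\boldsymbol{\theta}; \mathcal{D}_{\text{test}}) \triangleq \frac{1}{|\mathcal{D}_{\text{test}}|} \sum_{(\boldsymbol{x}, y) \in \mathcal{D}_{\text{test}}} \ell(y, f(\boldsymbol{x}; \boldsymbol{\theta})) \tag{1.32} $$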


As an example, in Figure 1.7d, we plot the training error and test error for polynomial regression
as a function of degree $D$. We see that the training error goes to 0 as the model becomes more
complex. However, the test error has a characteristic U-shaped curve: on the left, where $D = 1$,
the model is underfitting; on the right, where $D \gg 1$, the model is overfitting; and when $D = 2$,
the model complexity is “just right”.

How can we pick a model of the right complexity? If we use the training set to evaluate different 
models, we will always pick the most complex model, since that will have the most degrees of 
freedom, and hence will have minimum loss. So instead we should pick the model with minimum 
test loss. 

In practice, we need to partition the data into three sets, namely the training set, the test set and 
a validation set; the latter is used for model selection, and we just use the test set to estimate 
future performance (the population risk), i.e., the test set is not used for model fitting or model 
selection. See Section 4.5.4 for further details. 
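
The three-way split can be sketched as follows for polynomial regression. Everything here is illustrative: the data is synthetic (a noisy quadratic), the split sizes 30/15/15 are arbitrary, and the helper names `fit` and `risk` are assumptions. The degree is chosen using the validation set only, and the test set is touched once at the end.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 60
x = rng.uniform(-1.0, 1.0, size=N)
y = 1.0 + 2.0 * x - 3.0 * x ** 2 + rng.normal(scale=0.2, size=N)  # synthetic data

# Partition the shuffled indices into training, validation and test sets.
idx = rng.permutation(N)
train, valid, test = idx[:30], idx[30:45], idx[45:]

def fit(x, y, degree):
    """Least squares polynomial fit of the given degree."""
    Phi = np.vander(x, degree + 1, increasing=True)
    w, *_ = np.linalg.lstsq(Phi, y, rcond=None)
    return w

def risk(x, y, w):
    """Empirical risk under quadratic loss (the MSE) on the given dataset."""
    Phi = np.vander(x, len(w), increasing=True)
    return np.mean((y - Phi @ w) ** 2)

degrees = range(1, 11)
weights = {d: fit(x[train], y[train], d) for d in degrees}
train_risk = {d: risk(x[train], y[train], weights[d]) for d in degrees}
valid_risk = {d: risk(x[valid], y[valid], weights[d]) for d in degrees}

best_degree = min(degrees, key=valid_risk.get)             # model selection
test_risk = risk(x[test], y[test], weights[best_degree])   # estimate of population risk
```

The training risk can only decrease as the degree grows, so it cannot be used to choose the degree; the validation risk, by contrast, penalizes both the underfit linear model and overly flexible high-degree fits.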


1.2.4 No free lunch theorem 


All models are wrong, but some models are useful. — George Box [BD87, p424].⁵


Given the large variety of models in the literature, it is natural to wonder which one is best. 
Unfortunately, there is no single best model that works optimally for all kinds of problems — this 
is sometimes called the no free lunch theorem [Wol96]. The reason is that a set of assumptions 
(also called inductive bias) that works well in one domain may work poorly in another. The best 


5. George Box is a retired statistics professor at the University of Wisconsin.

way to pick a suitable model is based on domain knowledge, and/or trial and error (i.e., using model
selection techniques such as cross validation (Section 4.5.4) or Bayesian methods (Section 5.2.2 and
Section 5.2.6)). For this reason, it is important to have many models and algorithmic techniques in
one’s toolbox to choose from.


1.3 Unsupervised learning 


In supervised learning, we assume that each input example x in the training set has an associated
set of output targets y, and our goal is to learn the input-output mapping. Although this is useful, 
and can be difficult, supervised learning is essentially just “glorified curve fitting” [Pea18]. 

An arguably much more interesting task is to try to “make sense of” data, as opposed to just 
learning a mapping. That is, we just get observed “inputs” D = {x_n : n = 1 : N} without any
corresponding “outputs” y_n. This is called unsupervised learning.

From a probabilistic perspective, we can view the task of unsupervised learning as fitting an 
unconditional model of the form p(x), which can generate new data x, whereas supervised learning
involves fitting a conditional model, p(y|x), which specifies (a distribution over) outputs given
inputs.⁶

Unsupervised learning avoids the need to collect large labeled datasets for training, which can 
often be time consuming and expensive (think of asking doctors to label medical images). 

Unsupervised learning also avoids the need to learn how to partition the world into often arbitrary 
categories. For example, consider the task of labeling when an action, such as “drinking” or “sipping”, 
occurs in a video. Is it when the person picks up the glass, or when the glass first touches the mouth, 
or when the liquid pours out? What if they pour out some liquid, then pause, then pour again — is 
that two actions or one? Humans will often disagree on such issues [Idr+17], which means the task is 
not well defined. It is therefore not reasonable to expect machines to learn such mappings.⁷

Finally, unsupervised learning forces the model to “explain” the high-dimensional inputs, rather 
than just the low-dimensional outputs. This allows us to learn richer models of “how the world works”. 
As Geoff Hinton, who is a famous professor of ML at the University of Toronto, has said: 


When we’re learning to see, nobody’s telling us what the right answers are — we just look. 
Every so often, your mother says “that’s a dog”, but that’s very little information. You’d be 
lucky if you got a few bits of information — even one bit per second — that way. The brain’s 
visual system has O(10^14) neural connections. And you only live for O(10^9) seconds. So it’s no
use learning one bit per second. You need more like O(10^5) bits per second. And there’s only
one place you can get that much information: from the input itself. — Geoffrey Hinton, 1996 
(quoted in [Gor06]). 


6. In the statistics community, it is common to use x to denote exogenous variables that are not modeled, but are
simply given as inputs. Therefore an unconditional model would be denoted p(y) rather than p(x).

7. A more reasonable approach is to try to capture the probability distribution over labels produced by a “crowd” of
annotators (see e.g., [Dum+18; Aro+19]). This embraces the fact that there can be multiple “correct” labels for a
given input due to the ambiguity of the task itself.


1.3.1 Clustering


Figure 1.8: (a) A scatterplot of the petal features from the iris dataset. (b) The result of unsupervised
clustering using K = 3. Generated by iris_kmeans.ipynb.

Figure 1.9: (a) Scatterplot of iris data (first 3 features). Points are color coded by class. (b) We fit a 2d
linear subspace to the 3d data using PCA. The class labels are ignored. Red dots are the original data, black
dots are points generated from the model using x̂ = Wz + μ, where z are latent points on the underlying
inferred 2d linear manifold. Generated by iris_pca.ipynb.

A simple example of unsupervised learning is the problem of finding clusters in data. The goal is to
partition the input into regions that contain “similar” points. As an example, consider a 2d version
of the Iris dataset. In Figure 1.8a, we show the points without any class labels. Intuitively there
are at least two clusters in the data, one in the bottom left and one in the top right. Furthermore,
if we assume that a “good” set of clusters should be fairly compact, then we might want to split
the top right into (at least) two subclusters. The resulting partition into three clusters is shown
in Figure 1.8b. (Note that there is no correct number of clusters; instead, we need to consider the
tradeoff between model complexity and fit to the data. We discuss ways to make this tradeoff in
Section 21.3.7.)
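
As a rough illustration of how such a partition can be computed, the sketch below implements the K-means algorithm (Lloyd's algorithm) with NumPy. The notebook name in the caption of Figure 1.8 suggests K-means was used there, but this code is an independent sketch on synthetic 2d blobs, not the iris data; the function names `assign` and `kmeans` are hypothetical.

```python
import numpy as np


def assign(X, centers):
    """Index of the nearest center (squared Euclidean distance) for each row of X."""
    dists = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)  # shape (N, K)
    return dists.argmin(axis=1)


def kmeans(X, K, num_iters=50, seed=0):
    """Lloyd's algorithm: alternate between assigning points and updating centers."""
    rng = np.random.default_rng(seed)
    # Initialize the centers at K distinct data points chosen at random.
    centers = X[rng.choice(len(X), size=K, replace=False)].copy()
    for _ in range(num_iters):
        labels = assign(X, centers)
        for k in range(K):
            if np.any(labels == k):  # keep the old center if a cluster is empty
                centers[k] = X[labels == k].mean(axis=0)
    return centers, assign(X, centers)


# Three synthetic 2d blobs (illustrative data only).
rng = np.random.default_rng(1)
blob_means = np.array([[0.0, 0.0], [5.0, 5.0], [5.0, 0.0]])
X = np.concatenate([m + 0.3 * rng.normal(size=(50, 2)) for m in blob_means])

centers, labels = kmeans(X, K=3)
```

K-means only finds a local optimum that depends on the initialization, and the number of clusters K must be supplied by the user, which is the model complexity tradeoff mentioned above.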


1.3.2 Discovering latent “factors of variation” 


When dealing with high-dimensional data, it is often useful to reduce the dimensionality by projecting
it to a lower dimensional subspace which captures the “essence” of the data. One approach to this
problem is to assume that each observed high-dimensional output x_n ∈ R^D was generated by a set
of hidden or unobserved low-dimensional latent factors z_n ∈ R^K. We can represent the model
diagrammatically as follows: z_n → x_n, where the arrow represents causation. Since we don’t know
the latent factors z_n, we often assume a simple prior probability model for p(z_n) such as a Gaussian,
which says that each factor is a random K-dimensional vector. If the data is real-valued, we can use
a Gaussian likelihood as well.

The simplest example is when we use a linear model, p(x_n|z_n; θ) = N(x_n|W z_n + μ, Σ). The
resulting model is called factor analysis (FA). It is similar to linear regression, except we only
observe the outputs x_n, and not the inputs z_n. In the special case that Σ = σ²I, this reduces to
a model called probabilistic principal components analysis (PCA), which we will explain in
Section 20.1. In Figure 1.9, we give an illustration of how this method can find a 2d linear subspace
when applied to some simple 3d data.
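
The following sketch shows the linear-subspace idea using classical PCA computed with an SVD. Note the substitution: this is not the maximum likelihood fitting procedure for probabilistic PCA, just the standard PCA projection. The synthetic 3d data, and the variable names, are assumptions standing in for the iris features of Figure 1.9.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic 3d data lying near a 2d plane (illustrative stand-in data).
N, D, K = 200, 3, 2
Z_true = rng.normal(size=(N, K))
W_true = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
offset = np.array([1.0, -2.0, 0.5])
X = Z_true @ W_true.T + offset + 0.05 * rng.normal(size=(N, D))

mu = X.mean(axis=0)
# The right singular vectors of the centered data give the principal directions.
_, _, Vt = np.linalg.svd(X - mu, full_matrices=False)
W = Vt[:K].T            # D x K matrix with orthonormal columns

Z = (X - mu) @ W        # latent coordinates z_n, shape (N, K)
X_hat = Z @ W.T + mu    # reconstructions x_hat_n = W z_n + mu

# Average squared reconstruction error per data point.
recon_error = float(np.mean(np.sum((X - X_hat) ** 2, axis=1)))
```

Because the data was generated close to a plane, two latent dimensions reconstruct it with an error on the order of the added noise.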

Of course, assuming a linear mapping from z_n to x_n is very restrictive. However, we can create
nonlinear extensions by defining p(x_n|z_n; θ) = N(x_n|f(z_n; θ), σ²I), where f(z; θ) is a nonlinear
model, such as a deep neural network. It becomes much harder to fit such a model (i.e., to estimate the
parameters θ), because the inputs to the neural net have to be inferred, as well as the parameters of
the model. However, there are various approximate methods, such as the variational autoencoder,
which can be applied (see Section 20.3.5).


1.3.3 Self-supervised learning 


A recently popular approach to unsupervised learning is known as self-supervised learning. In this 
approach, we create proxy supervised tasks from unlabeled data. For example, we might try to learn 
to predict a color image from a grayscale image, or to mask out words in a sentence and then try to 
predict them given the surrounding context. The hope is that the resulting predictor x̂_1 = f(x_2; θ),
where x_2 is the observed input and x̂_1 is the predicted output, will learn useful features from the
data, that can then be used in standard, downstream supervised tasks. This avoids the hard problem 
of trying to infer the “true latent factors” z behind the observed data, and instead relies on standard 
supervised learning methods. We discuss this approach in more detail in Section 19.2. 
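
The word-masking proxy task can be sketched as follows. This only shows how (input, target) pairs are created from unlabeled text; the `[MASK]` token name, the masking probability and the function name are illustrative assumptions, and no predictor is trained here.

```python
import random

MASK = "[MASK]"  # placeholder token; the name is an illustrative choice


def make_masked_example(tokens, mask_prob=0.15, rng=None):
    """Return (masked_tokens, targets), where targets maps position -> original word."""
    rng = rng or random.Random(0)
    masked, targets = [], {}
    for i, tok in enumerate(tokens):
        if rng.random() < mask_prob:
            masked.append(MASK)   # observed input hides this word
            targets[i] = tok      # the hidden word becomes the prediction target
        else:
            masked.append(tok)
    return masked, targets


sentence = "the cat sat on the mat".split()
masked, targets = make_masked_example(sentence, mask_prob=0.5, rng=random.Random(1))
```

The masked sequence plays the role of the observed input, and the hidden words play the role of the outputs to be predicted, so any standard supervised learner can be trained on such pairs.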


1.3.4 Evaluating unsupervised learning 


Although unsupervised learning is appealing, it is very hard to evaluate the quality of the output of 
an unsupervised learning method, because there is no ground truth to compare to [TOB16]. 

A common method for evaluating unsupervised models is to measure the probability assigned by 
the model to unseen test examples. We can do this by computing the (unconditional) negative log 
likelihood of the data: 


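$$\mathcal{L}(\boldsymbol{\theta}; \mathcal{D}) = -\frac{1}{|\mathcal{D}|} \sum_{\boldsymbol{x} \in \mathcal{D}} \log p(\boldsymbol{x}|\boldsymbol{\theta}) \tag{1.33}$$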


This treats the problem of unsupervised learning as one of density estimation. The idea is that a 
good model will not be “surprised” by actual data samples (i.e., will assign them high probability). 
Furthermore, since probabilities must sum to 1.0, if the model assigns high probability to regions of 
data space where the data samples come from, it implicitly assigns low probability to the regions 
where the data does not come from. Thus the model has learned to capture the typical patterns 
in the data. This can be used inside of a data compression algorithm.

Figure 1.10: Examples of some control problems. (a) Space Invaders Atari game. From
https://gymnasium.farama.org/environments/atari/space_invaders/. (b) Controlling a humanoid robot in the MuJoCo
simulator so it walks as fast as possible without falling over. From
https://gymnasium.farama.org/environments/mujoco/humanoid/.


Unfortunately, density estimation is difficult, especially in high dimensions. Furthermore, a model 
that assigns high probability to the data may not have learned useful high-level patterns (after all, 
the model could just memorize all the training examples). 
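
As a small worked example of Equation (1.33), the sketch below evaluates the average negative log likelihood of held-out data under a 1d Gaussian density. The Gaussian model, the synthetic data and the deliberately shifted comparison model are assumptions made for illustration.

```python
import math
import random


def gaussian_logpdf(x, mean, var):
    """Log density of a univariate Gaussian N(x | mean, var)."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def nll(data, mean, var):
    """Average negative log likelihood of the data, as in Equation (1.33)."""
    return -sum(gaussian_logpdf(x, mean, var) for x in data) / len(data)


rng = random.Random(0)
train = [rng.gauss(2.0, 1.0) for _ in range(500)]
test = [rng.gauss(2.0, 1.0) for _ in range(500)]

# Fit the density by maximum likelihood on the training set.
mean_hat = sum(train) / len(train)
var_hat = sum((x - mean_hat) ** 2 for x in train) / len(train)

# Evaluate on unseen test examples.
nll_fitted = nll(test, mean_hat, var_hat)
# A model that places its probability mass in the wrong region is more "surprised".
nll_shifted = nll(test, mean_hat + 5.0, var_hat)
```

Evaluating on held-out data, rather than on the training set, is what guards against rewarding a model that has merely memorized the training examples.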

An alternative evaluation metric is to use the learned unsupervised representation as features or 
input to a downstream supervised learning method. If the unsupervised method has discovered useful 
patterns, then it should be possible to use these patterns to perform supervised learning using much 
less labeled data than when working with the original features. For example, in Section 1.2.1.1, we 
saw how the 4 manually defined features of iris flowers contained most of the information needed 
to perform classification. We were thus able to train a classifier with nearly perfect performance 
using just 150 examples. If the input was raw pixels, we would need many more examples to achieve 
comparable performance (see Section 14.1). That is, we can increase the sample efficiency of 
learning (i.e., reduce the number of labeled examples needed to get good performance) by first 
learning a good representation. 

Increased sample efficiency is a useful evaluation metric, but in many applications, especially in 
science, the goal of unsupervised learning is to gain understanding, not to improve performance on 
some prediction task. This requires the use of models that are interpretable, but which can also 
generate or “explain” most of the observed patterns in the data. To paraphrase Plato, the goal is 
to discover how to “carve nature at its joints”. Of course, evaluating whether we have successfully 
discovered the true underlying structure behind some dataset often requires performing experiments 
and thus interacting with the world. We discuss this topic further in Section 1.4. 


1.4 Reinforcement learning 


In addition to supervised and unsupervised learning, there is a third kind of ML known as reinforcement
learning (RL). In this class of problems, the system or agent has to learn how to interact
with its environment. This can be encoded by means of a policy a = π(x), which specifies which
action to take in response to each possible input x (derived from the environment state).

For example, consider an agent that learns to play a video game, such as Atari Space Invaders (see
Figure 1.10a). In this case, the input x is the image (or sequence of past images), and the output a
is the direction to move in (left or right) and whether to fire a missile or not. As a more complex
example, consider the problem of a robot learning to walk (see Figure 1.10b). In this case, the input
x is the set of joint positions and angles for all the limbs, and the output a is a set of actuation or
motor control signals.

[Figure 1.11 slide text: “Pure” reinforcement learning (cherry): the machine predicts a scalar reward given
once in a while; a few bits for some samples. Supervised learning (icing): the machine predicts a category
or a few numbers for each input; predicting human-supplied data; 10-10,000 bits per sample.
Unsupervised/predictive learning (cake): the machine predicts any part of its input for any observed part;
predicts future frames in videos; millions of bits per sample.]

Figure 1.11: The three types of machine learning visualized as layers of a chocolate cake. This figure (originally
from https://bit.ly/2m65Vs1) was used in a talk by Yann LeCun at NIPS’16, and is used with his kind
permission.

The difference from supervised learning (SL) is that the system is not told which action is the 
best one to take (i.e., which output to produce for a given input). Instead, the system just receives 
an occasional reward (or punishment) signal in response to the actions that it takes. This is like 
learning with a critic, who gives an occasional thumbs up or thumbs down, as opposed to learning 
with a teacher, who tells you what to do at each step. 
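
The interaction between a policy and an environment can be sketched with a toy example. Everything here is hypothetical: the `Corridor` environment, its reward rule and the two hand-written policies are illustrative, and no learning takes place. The point is only that the agent sees inputs and occasional rewards, never the correct action.

```python
import random


class Corridor:
    """Hypothetical toy environment: walk from cell 0 to the goal cell."""

    def __init__(self, length=5):
        self.length = length
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state

    def step(self, action):
        """action is -1 (left) or +1 (right); a reward is given only at the goal."""
        self.state = min(max(self.state + action, 0), self.length - 1)
        done = self.state == self.length - 1
        reward = 1.0 if done else 0.0
        return self.state, reward, done


def random_policy(x, rng):
    """Ignore the input x and pick a direction at random."""
    return rng.choice([-1, +1])


def always_right_policy(x, rng):
    """Always move towards the goal."""
    return +1


def run_episode(env, policy, rng, max_steps=100):
    """Roll out a policy a = pi(x) and accumulate the reward signal."""
    x = env.reset()
    total_reward, steps = 0.0, 0
    for _ in range(max_steps):
        a = policy(x, rng)
        x, r, done = env.step(a)
        total_reward += r
        steps += 1
        if done:
            break
    return total_reward, steps


rng = random.Random(0)
reward_right, steps_right = run_episode(Corridor(), always_right_policy, rng)
reward_random, steps_random = run_episode(Corridor(), random_policy, rng)
```

Both policies receive the same sparse signal (a single reward on reaching the goal), so the feedback alone does not say which individual actions were good.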

RL has grown in popularity recently, due to its broad applicability (since the reward signal that 
the agent is trying to optimize can be any metric of interest). However, it can be harder to make RL 
work than it is for supervised or unsupervised learning, for a variety of reasons. A key difficulty is 
that the reward signal may only be given occasionally (e.g., if the agent eventually reaches a desired 
state), and even then it may be unclear to the agent which of its many actions were responsible for 
getting the reward. (Think of playing a game like chess, where there is a single win or lose signal at 
the end of the game.) 

To compensate for the minimal amount of information coming from the reward signal, it is common 
to use other information sources, such as expert demonstrations, which can be used in a supervised 
way, or unlabeled data, which can be used by an unsupervised learning system to discover the 
underlying structure of the environment. This can make it feasible to learn from a limited number of 
trials (interactions with the environment). As Yann LeCun put it, in an invited talk at the NIPS⁸
conference in 2016: “If intelligence was a cake, unsupervised learning would be the chocolate sponge, 
supervised learning would be the icing, and reinforcement learning would be the cherry.” This is 
illustrated in Figure 1.11. 


More information on RL can be found in the sequel to this book, [Mur23].

8. NIPS stands for “Neural Information Processing Systems”. It is one of the premier ML conferences. It has recently
been renamed to NeurIPS.

Figure 1.12: (a) Visualization of the MNIST dataset. Each image is 28 x 28. There are 60k training examples
and 10k test examples. We show the first 25 images from the training set. Generated by mnist_viz_tf.ipynb.
(b) Visualization of the EMNIST dataset. There are 697,982 training examples, and 116,323 test examples,
each of size 28 x 28. There are 62 classes (a-z, A-Z, 0-9). We show the first 25 images from the training set.
Generated by emnist_viz_jax.ipynb.


1.5 Data 


Machine learning is concerned with fitting models to data using various algorithms. Although we 
focus on the modeling and algorithm aspects, it is important to mention that the nature and quality 
of the training data also plays a vital role in the success of any learned model. 

In this section, we briefly describe some common image and text datasets that we will use in this 
book. We also briefly discuss the topic of data preprocessing. 


1.5.1 Some common image datasets 


In this section, we briefly discuss some image datasets that we will use in this book. 


1.5.1.1 Small image datasets 


One of the simplest and most widely used is known as MNIST [LeC+98; YB19].⁹ This is a dataset
of 60k training images and 10k test images, each of size 28 x 28 (grayscale), illustrating handwritten 
digits from 10 categories. Each pixel is an integer in the range {0,1,...,255}; these are usually 
rescaled to [0,1], to represent pixel intensity. We can optionally convert this to a binary image by 
thresholding. See Figure 1.12a for an illustration. 
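
The rescaling and optional binarization steps can be sketched as follows. The random array is a stand-in for real MNIST images (loading the dataset is omitted), and the threshold of 0.5 is an assumed value, not one specified in the text.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a batch of MNIST images: integers in {0, ..., 255}, shape (N, 28, 28).
images_uint8 = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Rescale pixel values to the interval [0, 1].
images = images_uint8.astype(np.float32) / 255.0

# Optionally convert to a binary image by thresholding (0.5 is an assumption).
binary = (images > 0.5).astype(np.uint8)
```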


9. The term “MNIST” stands for “Modified National Institute of Standards and Technology”. The term “modified” is used because the
images have been preprocessed to ensure the digits are mostly in the center of the image.

Figure 1.13: (a) Visualization of the Fashion-MNIST dataset [XRV17]. The dataset has the same size
as MNIST, but is harder to classify. There are 10 classes: T-shirt/top, Trouser, Pullover, Dress, Coat, 
Sandal, Shirt, Sneaker, Bag, Ankle-boot. We show the first 25 images from the training set. Generated by 
fashion_viz_tf.ipynb. (b) Some images from the CIFAR-10 dataset [KH09]. Each image is 32 x 32 x 3, where 
the final dimension of size 3 refers to RGB. There are 50k training examples and 10k test examples. There 
are 10 classes: plane, car, bird, cat, deer, dog, frog, horse, ship, and truck. We show the first 25 images from 
the training set. Generated by cifar_viz_tf.ipynb. 


MNIST is so widely used in the ML community that Geoff Hinton, a famous ML researcher, has 
called it the “drosophila of machine learning”, since if we cannot make a method work well on MNIST, 
it will likely not work well on harder datasets. However, nowadays MNIST classification is considered 
“too easy”, since it is possible to distinguish most pairs of digits by looking at just a single pixel. 
Various extensions have been proposed. 

In [Coh+17], they proposed EMNIST (extended MNIST), which also includes lower and upper
case letters. See Figure 1.12b for a visualization. This dataset is much harder than MNIST, since
there are 62 classes, several of which are quite ambiguous (e.g., the digit 1 vs the lower case letter l).

In [XRV17], they proposed Fashion-MNIST, which has exactly the same size and shape as 
MNIST, but where each image is the picture of a piece of clothing instead of a handwritten digit. 
See Figure 1.13a for a visualization. 

For small color images, the most common dataset is CIFAR [KH09].¹⁰ This is a dataset of 60k
images, each of size 32 x 32 x 3, representing everyday objects from 10 or 100 classes; see Figure 1.13b
for an illustration.¹¹

Figure 1.14: (a) Sample images from the ImageNet dataset [Rus+15]. This subset consists of 1.39M color
training images, each of which is 256 x 256 pixels in size. There are 1000 possible labels, one per image, and
the task is to minimize the top-5 error rate, i.e., to ensure the correct label is within the 5 most probable
predictions. Below each image we show the true label, and a distribution over the top 5 predicted labels. If the
true label is in the top 5, its probability bar is colored red. Predictions are generated by a convolutional neural
network (CNN) called “AlexNet” (Section 14.3.2). From Figure 4 of [KSH12]. Used with kind permission of
Alex Krizhevsky. (b) Misclassification rate (top 5) on the ImageNet competition over time. Used with kind
permission of Andrej Karpathy.


1.5.1.2 ImageNet 


Small datasets are useful for prototyping ideas, but it is also important to test methods on larger 
datasets, both in terms of image size and number of labeled examples. The most widely used dataset 
of this type is called ImageNet [Rus+15]. This is a dataset of ~ 14M images of size 256 x 256 x 3 
illustrating various objects from 20,000 classes; see Figure 1.14a for some examples. 

The ImageNet dataset was used as the basis of the ImageNet Large Scale Visual Recognition 
Challenge (ILSVRC), which ran from 2010 to 2018. This used a subset of 1.3M images from 1000 
classes. During the course of the competition, significant progress was made by the community, as 
shown in Figure 1.14b. In particular, 2015 marked the first year in which CNNs could outperform 
humans (or at least one human, namely Andrej Karpathy) at the task of classifying images from 
ImageNet. Note that this does not mean that CNNs are better at vision than humans (see e.g., 
[YL21] for some common failure modes). Instead, it most likely reflects the fact that the dataset
makes many fine-grained classification distinctions (such as between a “tiger” and a “tiger cat”)
that humans find difficult to understand; by contrast, sufficiently flexible CNNs can learn arbitrary
patterns, including random labels [Zha+17a].
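
The top-5 misclassification rate used to report ILSVRC results (see Figure 1.14) counts an example as correct if its true label is among the five highest-scoring classes. A minimal sketch, with made-up scores over six classes and a hypothetical function name:

```python
def top_k_error(scores, labels, k=5):
    """Fraction of examples whose true label is not among the k highest-scoring classes."""
    errors = 0
    for s, y in zip(scores, labels):
        # Class indices sorted by decreasing score; ties keep the lower index first.
        top_k = sorted(range(len(s)), key=lambda c: s[c], reverse=True)[:k]
        if y not in top_k:
            errors += 1
    return errors / len(labels)


# Two examples with 6 classes each (made-up predicted probabilities).
scores = [
    [0.05, 0.10, 0.40, 0.20, 0.15, 0.10],
    [0.30, 0.25, 0.20, 0.15, 0.09, 0.01],
]
labels = [3, 5]

top5 = top_k_error(scores, labels, k=5)
top1 = top_k_error(scores, labels, k=1)
```

For these scores the first example is wrong under top-1 but correct under top-5, while the second is wrong under both, so the top-5 error is lower than the top-1 error.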


10. CIFAR stands for “Canadian Institute For Advanced Research”. This is the agency that funded labeling of 
the dataset, which was derived from the TinyImages dataset at http://groups.csail.mit.edu/vision/TinyImages/
created by Antonio Torralba. See [KH09] for details. 

11. Despite its popularity, the CIFAR dataset has some issues. For example, the base error on CIFAR-100 is 5.85% due
to mislabeling [NAM21]. This makes any results with accuracy above 94.15% suspicious. Also, 10% of CIFAR-100
training set images are duplicated in the test set [BD20].

1. this film was just brilliant casting location scenery story direction everyone's really suited the part they played robert
<UNK> is an amazing actor ...

2. big hair big boobs bad music and a giant safety pin these are the words to best describe this terrible movie i love cheesy
horror movies and i’ve seen hundreds...


Table 1.3: We show snippets of the first two sentences from the IMDB movie review dataset. The first example
is labeled positive and the second negative. (<UNK> refers to an unknown token.)


Although ImageNet is much harder than MNIST and CIFAR as a classification benchmark, it too 
is almost “saturated” [Bey+20]. Nevertheless, relative performance of methods on ImageNet is often
a surprisingly good predictor of performance on other, unrelated image classification tasks (see e.g., 
[Rec+19]), so it remains very widely used. 


1.5.2 Some common text datasets 


Machine learning is often applied to text to solve a variety of tasks. This is known as natural 
language processing or NLP (see e.g., [JM20] for details). Below we briefly mention a few text
datasets that we will use in this book. 


1.5.2.1 Text classification 


A simple NLP task is text classification, which can be used for email spam classification, sentiment
analysis (e.g., is a movie or product review positive or negative), etc. A common dataset for
evaluating such methods is the IMDB movie review dataset from [Maa+11]. (IMDB stands for 
“Internet Movie Database”.) This contains 25k labeled examples for training, and 25k for testing. 
Each example has a binary label, representing a positive or negative rating. See Table 1.3 for some 
example sentences. 


1.5.2.2 Machine translation 


A more difficult NLP task is to learn to map a sentence x in one language to a “semantically equivalent”
sentence y in another language; this is called machine translation. Training such models requires 
aligned (x,y) pairs. Fortunately, several such datasets exist, e.g., from the Canadian parliament 
(English-French pairs), and the European Union (Europarl). A subset of the latter, known as the 
WMT dataset (Workshop on Machine Translation), consists of English-German pairs, and is widely 
used as a benchmark dataset. 


1.5.2.3 Other seq2seq tasks 


A generalization of machine translation is to learn a mapping from one sequence x to any other
sequence y. This is called a seq2seq model, and can be viewed as a form of high-dimensional
classification (see Section 15.2.3 for details). This framing of the problem is very general, and
includes many tasks, such as document summarization, question answering, etc. For example,
Table 1.4 shows how to formulate question answering as a seq2seq problem: the input is the text T
and question Q, and the output is the answer A, which is a set of words, possibly extracted from the
input.

T: In meteorology, precipitation is any product of the condensation of atmospheric water vapor that falls under gravity. The
main forms of precipitation include drizzle, rain, sleet, snow, graupel and hail... Precipitation forms as smaller droplets
coalesce via collision with other rain drops or ice crystals within a cloud. Short, intense periods of rain in scattered
locations are called “showers”.


Q1: What causes precipitation to fall? A1: gravity
Q2: What is another main form of precipitation besides drizzle, rain, snow, sleet and hail? A2: graupel
Q3: Where do water droplets collide with ice crystals to form precipitation? A3: within a cloud


Table 1.4: Question-answer pairs for a sample passage in the SQuAD dataset. Each of the answers is a
segment of text from the passage. This can be solved using sentence pair tagging. The input is the paragraph
text T and the question Q. The output is a tagging of the relevant words in T that answer the question in Q.
From Figure 1 of [Raj+16]. Used with kind permission of Percy Liang.


1.5.2.4 Language modeling 


The rather grandiose term “language modeling” refers to the task of creating unconditional 
generative models of text sequences, p(x_1, ..., x_T). This only requires input sentences x, without
any corresponding “labels” y. We can therefore think of this as a form of unsupervised learning, 
which we discuss in Section 1.3. If the language model generates output in response to an input, as 
in seq2seq, we can regard it as a conditional generative model. 


1.5.3 Preprocessing discrete input data 


Many ML models assume that the data consists of real-valued feature vectors, x ∈ R^D. However,
sometimes the input may have discrete input features, such as categorical variables like race and 
gender, or words from some vocabulary. In the sections below, we discuss some ways to preprocess 
such data to convert it to vector form. This is a common operation that is used for many different 
kinds of models. 


1.5.3.1 One-hot encoding 


When we have categorical features, we need to convert them to a numerical scale, so that computing 
weighted combinations of the inputs makes sense. The standard way to preprocess such categorical 
variables is to use a one-hot encoding, also called a dummy encoding. If a variable x has K 
values, we will denote its dummy encoding as follows: one-hot(x) = [I(x = 1), ..., I(x = K)]. For
example, if there are 3 colors (say red, green and blue), the corresponding one-hot vectors will be 
one-hot(red) = [1,0,0], one-hot(green) = [0, 1,0], and one-hot(blue) = [0, 0, 1]. 
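
A minimal sketch of this encoding in plain Python is shown below; the function name `one_hot` and the choice to raise an error on unknown values are illustrative assumptions.

```python
def one_hot(value, categories):
    """One-hot (dummy) encoding of value with respect to an ordered list of categories."""
    if value not in categories:
        raise ValueError(f"unknown category: {value!r}")
    return [1 if value == c else 0 for c in categories]


colors = ["red", "green", "blue"]
encoded = {c: one_hot(c, colors) for c in colors}
```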


1.5.3.2 Feature crosses 


A linear model using a dummy encoding for each categorical variable can capture the main effects
of each variable, but cannot capture interaction effects between them. For example, suppose we
want to predict the fuel efficiency of a vehicle given two categorical input variables: the type (say
SUV, Truck, or Family car), and the country of origin (say USA or Japan). If we concatenate the
one-hot encodings for the ternary and binary features, we get the following input encoding:


$$\phi(\boldsymbol{x}) = [1, \mathbb{I}(x_1 = S), \mathbb{I}(x_1 = T), \mathbb{I}(x_1 = F), \mathbb{I}(x_2 = U), \mathbb{I}(x_2 = J)] \tag{1.34}$$


where x_1 is the type and x_2 is the country of origin.

This model cannot capture dependencies between the features. For example, we expect trucks to 
be less fuel efficient, but perhaps trucks from the USA are even less efficient than trucks from Japan. 
This cannot be captured using the linear model in Equation (1.34) since the contribution from the 
country of origin is independent of the car type. 

We can fix this by computing explicit feature crosses. For example, we can define a new composite 
feature with 3 x 2 possible values, to capture the interaction of type and country of origin. The new 
model becomes 


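$$
\begin{aligned}
f(\boldsymbol{x}; \boldsymbol{w}) &= \boldsymbol{w}^{\mathsf{T}} \phi(\boldsymbol{x}) && (1.35) \\
&= w_0 + w_1 \mathbb{I}(x_1 = S) + w_2 \mathbb{I}(x_1 = T) + w_3 \mathbb{I}(x_1 = F) \\
&\quad + w_4 \mathbb{I}(x_2 = U) + w_5 \mathbb{I}(x_2 = J) \\
&\quad + w_6 \mathbb{I}(x_1 = S, x_2 = U) + w_7 \mathbb{I}(x_1 = T, x_2 = U) + w_8 \mathbb{I}(x_1 = F, x_2 = U) \\
&\quad + w_9 \mathbb{I}(x_1 = S, x_2 = J) + w_{10} \mathbb{I}(x_1 = T, x_2 = J) + w_{11} \mathbb{I}(x_1 = F, x_2 = J) && (1.36)
\end{aligned}
$$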


We can see that the use of feature crosses converts the original dataset into a wide format, with 
many more columns. 
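
The two encodings can be sketched as follows. The single-letter codes follow Equations (1.34) and (1.36); the function names `phi_main` and `phi_crossed` are hypothetical.

```python
from itertools import product

TYPES = ["S", "T", "F"]     # SUV, Truck, Family car
COUNTRIES = ["U", "J"]      # USA, Japan


def phi_main(x1, x2):
    """Bias term plus concatenated one-hot encodings, as in Equation (1.34)."""
    return [1] + [int(x1 == t) for t in TYPES] + [int(x2 == c) for c in COUNTRIES]


def phi_crossed(x1, x2):
    """Main effects plus one indicator per (type, country) pair, ordered as in Equation (1.36)."""
    # The outer loop runs over countries and the inner loop over types,
    # matching the order of the weights w_6, ..., w_11.
    crosses = [int(x1 == t and x2 == c) for c, t in product(COUNTRIES, TYPES)]
    return phi_main(x1, x2) + crosses


main_features = phi_main("T", "U")        # a truck from the USA
crossed_features = phi_crossed("T", "U")
```

The crossed encoding grows from 6 to 12 columns, which is the "wide format" effect described above; with more variables or more values per variable, the number of crossed columns grows multiplicatively.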


1.5.4 Preprocessing text data 


In Section 1.5.2, we briefly discussed text classification and other NLP tasks. To feed text data into 
a classifier, we need to tackle various issues. First, documents have a variable length, and are thus 
not fixed-length feature vectors, as assumed by many kinds of models. Second, words are categorical 
variables with many possible values (equal to the size of the vocabulary), so the corresponding 
one-hot encodings will be very high-dimensional, with no natural notion of similarity. Third, we may 
encounter words at test time that have not been seen during training (so-called out-of-vocabulary 
or OOV words). We discuss some solutions to these problems below. More details can be found in 
e.g., [BKL10; MRS08; JM20]. 


1.5.4.1 Bag of words model 


A simple approach to dealing with variable-length text documents is to interpret them as a bag of 
words, in which we ignore word order. To convert this to a vector from a fixed input space, we first 
map each word to a token from some vocabulary. 

To reduce the number of tokens, we often use various pre-processing techniques such as the following:
dropping punctuation; converting all words to lower case; dropping common but uninformative words,
such as “and” and “the” (this is called stop word removal); replacing words with their base form,
such as replacing “running” and “runs” with “run” (this is called word stemming); etc. For details,
see e.g., [BL12], and for some sample code, see text_preproc_jax.ipynb.

Figure 1.15: Example of a term-document matrix, where raw counts have been replaced by their TF-IDF
values (see Section 1.5.4.2). Darker cells are larger values. From https://bit.ly/2kByLQI. Used with
kind permission of Christoph Carl Kling.


Let x_{nt} be the token at location t in the n’th document. If there are D unique tokens in the
vocabulary, then we can represent the n’th document as a D-dimensional vector x̃_n, where x̃_{nv} is
the number of times that word v occurs in document n:


$$\tilde{x}_{nv} = \sum_{t=1}^{T} \mathbb{I}(x_{nt} = v) \tag{1.37}$$


where T is the length of document n. We can now interpret documents as vectors in R^D. This is
called the vector space model of text [SWY75; TP10].
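
The preprocessing steps and the count vectors of Equation (1.37) can be sketched in a few lines of plain Python. This is not the code from text_preproc_jax.ipynb: the stop word list is a tiny illustrative one, and the `STEMS` lookup table is a toy stand-in for a real stemming algorithm.

```python
import re
from collections import Counter

STOP_WORDS = {"and", "the", "a", "is", "of"}   # tiny illustrative list
STEMS = {"running": "run", "runs": "run"}      # toy stand-in for a real stemmer


def preprocess(text):
    """Lower-case, drop punctuation, remove stop words, and map words to a base form."""
    tokens = re.findall(r"[a-z]+", text.lower())
    tokens = [t for t in tokens if t not in STOP_WORDS]
    return [STEMS.get(t, t) for t in tokens]


def bag_of_words(docs):
    """Return (vocab, counts), where counts[n][v] is the count of vocab[v] in document n."""
    tokenized = [preprocess(d) for d in docs]
    vocab = sorted({t for doc in tokenized for t in doc})
    index = {w: v for v, w in enumerate(vocab)}
    counts = []
    for doc in tokenized:
        row = [0] * len(vocab)
        for token, c in Counter(doc).items():
            row[index[token]] = c
        counts.append(row)
    return vocab, counts


docs = ["The dog runs and the cat runs.", "A cat is running!"]
vocab, counts = bag_of_words(docs)
```

Each row of `counts` is a fixed-length vector regardless of document length, and word order has been discarded.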

We traditionally store input data in an N x D design matrix denoted by X, where D is the number 
of features. In the context of vector space models, it is more common to represent the input data 
as a D x N term frequency matrix, where TF_{ij} is the frequency of term i in document j. See
Figure 1.15 for an illustration. 


1.5.4.2 TF-IDF 


One problem with representing documents as word count vectors is that frequent words may have 
undue influence, just because the magnitude of their word count is higher, even if they do not carry 
much semantic content. A common solution to this is to transform the counts by taking logs, which 
reduces the impact of words that occur many times within a single document. 

To reduce the impact of words that occur many times in general (across all documents), we compute
a quantity called the inverse document frequency, defined as follows: IDF_i ≜ log(N / (1 + DF_i)), where
N is the number of documents and DF_i is the number of documents with term i. We can combine these
transformations to compute the TF-IDF matrix as follows:


$$\text{TFIDF}_{ij} = \log(\text{TF}_{ij} + 1) \times \text{IDF}_i \tag{1.38}$$


(We often normalize each row as well.) This provides a more meaningful representation of documents, 
and can be used as input to many ML algorithms. See tfidf_demo.ipynb for an example. 
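
A minimal sketch of Equation (1.38) in plain Python follows. It assumes the smoothed definition IDF_i = log(N / (1 + DF_i)), where N is the number of documents; other variants of the IDF term are in common use, and row normalization is omitted. The counts are made up, and this is not the code from tfidf_demo.ipynb.

```python
import math


def tfidf(tf):
    """tf[i][j] is the count of term i in document j (a D x N term frequency matrix)."""
    num_terms, num_docs = len(tf), len(tf[0])
    out = []
    for i in range(num_terms):
        # Document frequency DF_i: number of documents containing term i.
        df = sum(1 for j in range(num_docs) if tf[i][j] > 0)
        idf = math.log(num_docs / (1 + df))   # assumed smoothed IDF
        out.append([math.log(tf[i][j] + 1) * idf for j in range(num_docs)])
    return out


# 3 terms x 4 documents (made-up counts).
tf = [
    [3, 2, 4, 1],   # occurs in every document
    [0, 0, 5, 0],   # occurs in a single document
    [1, 0, 0, 0],
]
weights = tfidf(tf)
```

Under this particular smoothing, a term that occurs in every document receives a slightly negative IDF, so its weights are pushed below those of absent terms; rarer terms receive larger positive weights.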
1.5.4.3 Word embeddings


Although the TF-IDF transformation improves on raw count vectors by placing more weight on 
“informative” words and less on “uninformative” words, it does not solve the fundamental problem with 
the one-hot encoding (from which count vectors are derived), which is that semantically similar
words, such as “man” and “woman”, may not be any closer (in vector space) than semantically
dissimilar words, such as “man” and “banana”. Thus the assumption that points that are close in 
input space should have similar outputs, which is implicitly made by most prediction models, is 
invalid. 

The standard way to solve this problem is to use word embeddings, in which we map each sparse
one-hot vector, $\mathbf{x}_{nt} \in \{0,1\}^V$, to a lower-dimensional dense vector, $\mathbf{e}_{nt} \in \mathbb{R}^K$, using $\mathbf{e}_{nt} = \mathbf{E} \mathbf{x}_{nt}$,
where $\mathbf{E} \in \mathbb{R}^{K \times V}$ is learned such that semantically similar words are placed close by. There are many
ways to learn such embeddings, as we discuss in Section 20.5. 

Once we have an embedding matrix, we can represent a variable-length text document as a bag of
word embeddings. We can then convert this to a fixed length vector by summing (or averaging)
the embeddings:


$$\bar{\mathbf{e}}_n = \sum_{t=1}^{T} \mathbf{e}_{nt} = \mathbf{E} \tilde{\mathbf{x}}_n \tag{1.39}$$


where $\tilde{\mathbf{x}}_n$ is the bag of words representation from Equation (1.37). We can then use this inside of a
logistic regression classifier, which we briefly introduced in Section 1.2.1.5. The overall model has the
form


$$p(y = c|\mathbf{x}_n, \boldsymbol{\theta}) = \mathrm{softmax}_c(\mathbf{W} \mathbf{E} \tilde{\mathbf{x}}_n) \tag{1.40}$$


We often use a pre-trained word embedding matrix E, in which case the model is linear in W, 
which simplifies parameter estimation (see Chapter 10). See also Section 15.7 for a discussion of 
contextual word embeddings. 
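To make Equations (1.39) and (1.40) concrete, here is a small sketch that checks that summing the per-token embeddings gives the same vector as multiplying the embedding matrix by the count vector, and then applies a softmax classifier. All sizes and numerical values are invented for illustration; in practice E would be pre-trained and W would be learned.

```python
import math

def matvec(M, v):
    """Multiply matrix M (list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def softmax(a):
    m = max(a)                         # subtract max for numerical stability
    exps = [math.exp(x - m) for x in a]
    z = sum(exps)
    return [e / z for e in exps]

# Hypothetical sizes: V = 4 vocabulary words, K = 2 embedding dims, C = 3 classes.
E = [[0.5, -1.0, 0.0, 2.0],
     [1.0, 0.5, -0.5, 0.0]]            # K x V embedding matrix (made-up values)
W = [[1.0, 0.0],
     [0.0, 1.0],
     [-1.0, -1.0]]                     # C x K classifier weights (made-up values)

tokens = [0, 3, 3, 1]                  # word indices of one document
x_tilde = [tokens.count(v) for v in range(4)]   # bag-of-words counts

# Summing per-token embeddings equals E times the count vector (Eq. 1.39).
e_sum = [sum(E[k][t] for t in tokens) for k in range(2)]
e_bar = matvec(E, x_tilde)
probs = softmax(matvec(W, e_bar))      # class probabilities (Eq. 1.40)
```

Because E is held fixed, the logits are a linear function of W, which is the property the text relies on when it says parameter estimation is simplified.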


1.5.4.4 Dealing with novel words 


At test time, the model may encounter a completely novel word that it has not seen before. This is 
known as the out of vocabulary or OOV problem. Such novel words are bound to occur, because 
the set of words is an open class. For example, the set of proper nouns (names of people and places) 
is unbounded. 

A standard heuristic to solve this problem is to replace all novel words with the special symbol 
UNK, which stands for “unknown”. However, this loses information. For example, if we encounter 
the word “athazagoraphobia”, we may guess it means “fear of something”, since phobia is a common
suffix in English (derived from Greek) to mean “fear of”. (It turns out that athazagoraphobia means 
“fear of being forgotten about or ignored”.) 

We could work at the character level, but this would require the model to learn how to group 
common letter combinations together into words. It is better to leverage the fact that words have 
substructure, and then to take as input subword units or wordpieces [SHB16; Wu+16]; these
are often created using a method called byte-pair encoding [Gag94], which is a form of data 
compression that creates new symbols to represent common substrings. 
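The idea behind byte-pair encoding can be sketched as follows: start from characters, then repeatedly merge the most frequent adjacent pair of symbols into a new symbol. This is a simplified illustration, not the exact procedure of the cited papers (for example, it uses no end-of-word marker); the word frequencies and the name `learn_bpe` are hypothetical.

```python
from collections import Counter

def learn_bpe(word_counts, num_merges):
    """Learn merge rules from a {word: count} dict."""
    # Represent each word as a tuple of symbols, starting from characters.
    vocab = {tuple(word): count for word, count in word_counts.items()}
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for symbols, count in vocab.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += count
        if not pairs:
            break
        best = max(pairs, key=pairs.get)   # most frequent adjacent pair
        merges.append(best)
        new_vocab = {}
        for symbols, count in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            key = tuple(merged)
            new_vocab[key] = new_vocab.get(key, 0) + count
        vocab = new_vocab
    return merges, vocab

# Hypothetical word frequencies.
merges, segmented = learn_bpe({"phobia": 5, "phone": 3, "photo": 2}, num_merges=2)
```

After two merges on this toy input, the shared prefix "pho" has become a single symbol, so a novel word starting with that prefix could be segmented into a known subword followed by characters instead of being mapped to UNK.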

1.5.5 Handling missing data


Sometimes we may have missing data, in which parts of the input x or output y may be unknown. 
If the output is unknown during training, the example is unlabeled; we consider such semi-supervised 
learning scenarios in Section 19.3. We therefore focus on the case where some of the input features 
may be missing, either at training or testing time, or both. 

To model this, let M be an N × D matrix of binary variables, where $M_{nd} = 1$ if feature d in
example n is missing, and $M_{nd} = 0$ otherwise. Let $X_v$ be the visible parts of the input feature matrix,
corresponding to $M_{nd} = 0$, and $X_h$ be the missing parts, corresponding to $M_{nd} = 1$. Let Y be the
output label matrix, which we assume is fully observed. If we assume $p(M|X_v, X_h, Y) = p(M)$, we
say the data is missing completely at random or MCAR, since the missingness does not depend
on the hidden or observed features. If we assume $p(M|X_v, X_h, Y) = p(M|X_v, Y)$, we say the data is
missing at random or MAR, since the missingness does not depend on the hidden features, but
may depend on the visible features. If neither of these assumptions holds, we say the data is not
missing at random or NMAR. 

In the MCAR and MAR cases, we can ignore the missingness mechanism, since it tells us nothing 
about the hidden features. However, in the NMAR case, we need to model the missing data 
mechanism, since the lack of information may be informative. For example, the fact that someone 
did not fill out an answer to a sensitive question on a survey (e.g., “Do you have COVID?”) could be 
informative about the underlying value. See e.g., [LR87; Mar08] for more information on missing 
data models. 

In this book, we will always make the MAR assumption. However, even with this assumption, we 
cannot directly use a discriminative model, such as a DNN, when we have missing input features, 
since the input x will have some unknown values. 

A common heuristic is called mean value imputation, in which missing values are replaced by 
their empirical mean. More generally, we can fit a generative model to the input, and use that to fill 
in the missing values. We briefly discuss some suitable generative models for this task in Chapter 20, 
and in more detail in the sequel to this book, [Mur23]. 
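Mean value imputation is simple enough to sketch directly. In the example below, `None` marks a missing entry (i.e., $M_{nd} = 1$); the feature matrix is invented, and the code assumes every column has at least one observed value.

```python
def mean_impute(X):
    """Replace None entries in each column by the mean of the observed entries."""
    n_cols = len(X[0])
    means = []
    for d in range(n_cols):
        observed = [row[d] for row in X if row[d] is not None]
        means.append(sum(observed) / len(observed))
    return [[means[d] if row[d] is None else row[d] for d in range(n_cols)]
            for row in X]

# Hypothetical 4 x 2 feature matrix; None marks a missing entry.
X = [[1.0, 10.0],
     [None, 20.0],
     [3.0, None],
     [5.0, 30.0]]
X_filled = mean_impute(X)
```

Note that this heuristic ignores any dependence between features, which is one reason a generative model of the inputs can do better.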


1.6 Discussion 


In this section, we situate ML and this book in a larger context.


1.6.1 The relationship between ML and other fields 


There are several subcommunities that work on ML-related topics, each of which has a different name.
The field of predictive analytics is similar to supervised learning (in particular, classification
and regression), but focuses more on business applications. Data mining covers both supervised
and unsupervised machine learning, but focuses more on structured data, usually stored in large
commercial databases. Data science uses techniques from machine learning and statistics, but
also emphasizes other topics, such as data integration, data visualization, and working with domain
experts, often in an iterative feedback loop (see e.g., [BS17]). The difference between these areas is
often just one of terminology.¹²


12. See https://developers.google.com/machine-learning/glossary/ for a useful “ML glossary”.

ML is also very closely related to the field of statistics. Indeed, Jerry Friedman, a famous statistics
professor at Stanford, said¹³


[If the statistics field had] incorporated computing methodology from its inception as a 
fundamental tool, as opposed to simply a convenient way to apply our existing tools, many of 
the other data related fields [such as ML] would not have needed to exist — they would have 
been part of statistics. — Jerry Friedman [Fri97b] 


Machine learning is also related to artificial intelligence (AI). Historically, the field of AI 
assumed that we could program “intelligence” by hand (see e.g., [RN10; PM17]), but this approach 
has largely failed to live up to expectations, mostly because it proved to be too hard to explicitly 
encode all the knowledge such systems need. Consequently, there is renewed interest in using ML to 
help an AI system acquire its own knowledge. (Indeed the connections are so close that sometimes 
the terms “ML” and “AI” are used interchangeably, although this is arguably misleading [Pre21].) 


1.6.2 Structure of the book 


We have seen that ML is closely related to many other subjects in mathematics, statistics, computer 
science, etc. It can be hard to know where to start. 

In this book, we take one particular path through this interconnected landscape, using probability 
theory as our unifying lens. We cover statistical foundations in Part I, supervised learning in 
Part II-Part IV, and unsupervised learning in Part V. For more information on these (and other)
topics, please see the sequel to this book, [Mur23].

In addition to the book, you may find the online Python notebooks that accompany this book
helpful. See probml.github.io/book1 for details.


1.6.3 Caveats 


In this book, we will see how machine learning can be used to create systems that can (attempt 
to) predict outputs given inputs. These predictions can then be used to choose actions so as to 
minimize expected loss. When designing such systems, it can be hard to design a loss function that 
correctly specifies all of our preferences; this can result in “reward hacking” in which the machine 
optimizes the reward function we give it, but then we realize that the function did not capture various 
constraints or preferences that we forgot to specify [Wei76; Amo+16; D’A+20]. (This is particularly 
important when tradeoffs need to be made between multiple objectives.) 

Reward hacking is an example of a larger problem known as the “alignment problem” [Chr20], 
which refers to the potential discrepancy between what we ask our algorithms to optimize and what 
we actually want them to do for us; this has raised various concerns in the context of AI ethics 
and AI safety (see e.g., [KR19; Lia20; Spe+22]). Russell [Rus19] proposes to solve this problem 
by not explicitly specifying a reward function, but instead forcing the machine to infer the reward 
by observing human behavior, an approach known as inverse reinforcement learning. However, 
emulating current or past human behavior too closely may be undesirable, and can be biased by the 
data that is available for training (see e.g., [Pau+20]). 

The above view of AI, in which an “intelligent” system makes decisions on its own, without a
human in the loop, is believed by many to be the path towards “artificial general intelligence”
or AGI. An alternative approach is to view AI as “augmented intelligence” (sometimes called
intelligence augmentation or IA). In this paradigm, AI is a process for creating “smart tools”,
like adaptive cruise control or auto-complete in search engines; such tools maintain a human in the
decision-making loop. In this framing, systems which have AI/ML components in them are not that
different from other complex, semi-autonomous human artefacts, such as aeroplanes with autopilot,
online trading platforms or medical diagnostic systems (cf. [Jor19; Ace]). Of course, as the AI tools
become more powerful, they can end up doing more and more on their own, making this approach
similar to AGI. However, in augmented intelligence, the goal is not to emulate or exceed human
behavior at certain tasks, but instead to help humans get stuff done more easily; this is how we treat
most other technologies [Kap16].


13. Quoted in https://brenocon.com/blog/2008/12/statistics-vs-machine-learning-fight/

Foundations


2 Probability: Univariate Models 


2.1 Introduction 


In this chapter, we give a brief introduction to the basics of probability theory. There are many good 
books that go into more detail, e.g., [GS97; BT08; Cha21]. 


2.1.1 What is probability? 


Probability theory is nothing but common sense reduced to calculation. — Pierre Laplace, 
1812 


We are all comfortable saying that the probability that a (fair) coin will land heads is 50%. But 
what does this mean? There are actually two different interpretations of probability. One is called 
the frequentist interpretation. In this view, probabilities represent long run frequencies of events 
that can happen multiple times. For example, the above statement means that, if we flip the coin 
many times, we expect it to land heads about half the time.¹

The other interpretation is called the Bayesian interpretation of probability. In this view, probability
is used to quantify our uncertainty or ignorance about something; hence it is fundamentally
related to information rather than repeated trials [Jay03; Lin06]. In the Bayesian view, the above 
statement means we believe the coin is equally likely to land heads or tails on the next toss. 

One big advantage of the Bayesian interpretation is that it can be used to model our uncertainty 
about one-off events that do not have long term frequencies. For example, we might want to compute 
the probability that the polar ice cap will melt by 2030 CE. This event will happen zero or one times, 
but cannot happen repeatedly. Nevertheless, we ought to be able to quantify our uncertainty about 
this event; based on how probable we think this event is, we can decide how to take the optimal 
action, as discussed in Chapter 5. We shall therefore adopt the Bayesian interpretation in this book. 
Fortunately, the basic rules of probability theory are the same, no matter which interpretation is 
adopted. 


2.1.2 Types of uncertainty 


The uncertainty in our predictions can arise for two fundamentally different reasons. The first is
due to our ignorance of the underlying hidden causes or mechanism generating our data. This is
called epistemic uncertainty, since epistemology is the philosophical term used to describe the
study of knowledge. However, a simpler term for this is model uncertainty. The second kind of
uncertainty arises from intrinsic variability, which cannot be reduced even if we collect more data.
This is sometimes called aleatoric uncertainty [Hac75; KD09], derived from the Latin word for
“dice”, although a simpler term would be data uncertainty. As a concrete example, consider tossing
a fair coin. We might know for sure that the probability of heads is p = 0.5, so there is no epistemic
uncertainty, but we still cannot perfectly predict the outcome.


1. Actually, the Stanford statistician (and former professional magician) Persi Diaconis has shown that a coin is about
51% likely to land facing the same way up as it started, due to the physics of the problem [DHM07].

This distinction can be important for applications such as active learning. A typical strategy is to 
query examples for which $\mathbb{H}(p(y|x, \mathcal{D}))$ is large (where $\mathbb{H}(p)$ is the entropy, discussed in Section 6.1).
However, this could be due to uncertainty about the parameters, i.e., large $\mathbb{H}(p(\theta|\mathcal{D}))$, or just due to
inherent variability of the outcome, corresponding to large entropy of $p(y|x, \theta)$. In the latter case,
there would not be much use collecting more samples, since our uncertainty would not be reduced. 
See [Osb16] for further discussion of this point. 


2.1.3 Probability as an extension of logic 

In this section, we review the basic rules of probability, following the presentation of [Jay03], in which 
we view probability as an extension of Boolean logic. 

2.1.3.1 Probability of an event 


We define an event, denoted by the binary variable A, as some state of the world that either holds
or does not hold. For example, A might be the event “it will rain tomorrow”, or “it rained yesterday”, or
“the label is y = 1”, or “the parameter θ is between 1.5 and 2.0”, etc. The expression Pr(A) denotes
the probability with which you believe event A is true (or the long run fraction of times that A will
occur). We require that 0 ≤ Pr(A) ≤ 1, where Pr(A) = 0 means the event definitely will not happen,
and Pr(A) = 1 means the event definitely will happen. We write $\Pr(\overline{A})$ to denote the probability of
event A not happening; this is defined to be $\Pr(\overline{A}) = 1 - \Pr(A)$.


2.1.3.2 Probability of a conjunction of two events 


We denote the joint probability of events A and B both happening as follows:

$$\Pr(A \wedge B) = \Pr(A, B) \tag{2.1}$$

If A and B are independent events, we have

$$\Pr(A, B) = \Pr(A) \Pr(B) \tag{2.2}$$

For example, suppose X and Y are chosen uniformly at random from the set $\mathcal{X} = \{1,2,3,4\}$. Let
A be the event that $X \in \{1,2\}$, and B be the event that $Y \in \{3\}$. Then we have
$\Pr(A, B) = \Pr(A) \Pr(B) = \frac{1}{2} \cdot \frac{1}{4}$.


2.1.3.3 Probability of a union of two events


The probability of event A or B happening is given by

$$\Pr(A \vee B) = \Pr(A) + \Pr(B) - \Pr(A \wedge B) \tag{2.3}$$

If the events are mutually exclusive (so they cannot happen at the same time), we get

$$\Pr(A \vee B) = \Pr(A) + \Pr(B) \tag{2.4}$$

For example, suppose X is chosen uniformly at random from the set $\mathcal{X} = \{1,2,3,4\}$. Let A be the
event that $X \in \{1,2\}$ and B be the event that $X \in \{3\}$. Then we have $\Pr(A \vee B) = \frac{2}{4} + \frac{1}{4}$.
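The two worked examples (conjunction of independent events and union of mutually exclusive events) can be checked by brute-force enumeration. The sketch below assumes X and Y are drawn independently, so that all 16 pairs are equally likely; the helper `prob` is a name introduced just for this illustration.

```python
from fractions import Fraction
from itertools import product

values = [1, 2, 3, 4]

def prob(event, outcomes):
    """Probability of an event under a uniform distribution over outcomes."""
    outcomes = list(outcomes)
    return Fraction(sum(1 for o in outcomes if event(o)), len(outcomes))

# Conjunction: X and Y drawn independently, A = {X in {1,2}}, B = {Y = 3}.
pairs = list(product(values, values))
p_joint = prob(lambda xy: xy[0] in {1, 2} and xy[1] == 3, pairs)
p_a = prob(lambda xy: xy[0] in {1, 2}, pairs)
p_b = prob(lambda xy: xy[1] == 3, pairs)

# Union: a single X, A = {X in {1,2}}, B = {X = 3} (mutually exclusive).
p_union = prob(lambda x: x in {1, 2} or x == 3, values)
```

Using exact fractions avoids any floating point tolerance when comparing `p_joint` with `p_a * p_b`.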


2.1.3.4 Conditional probability of one event given another 


We define the conditional probability of event B happening given that A has occurred as follows:


$$\Pr(B|A) \triangleq \frac{\Pr(A, B)}{\Pr(A)} \tag{2.5}$$


This is not defined if Pr(A) = 0, since we cannot condition on an impossible event. 


2.1.3.5 Independence of events 


We say that event A is independent of event B if 


Pr(A, B) = Pr(A) Pr(B) (2.6) 


2.1.3.6 Conditional independence of events 


We say that events A and B are conditionally independent given event C if 
Pr(A, B|C) = Pr(A|C) Pr(B|C) (2.7) 


This is written as $A \perp B \mid C$. Events are often dependent on each other, but may be rendered
independent if we condition on the relevant intermediate variables, as we discuss in more detail later 
in this chapter. 


2.2 Random variables 


Suppose X represents some unknown quantity of interest, such as which way a die will land when
we roll it, or the temperature outside your house at the current time. If the value of X is unknown
and/or could change, we call it a random variable or rv. The set of possible values, denoted $\mathcal{X}$, is
known as the sample space or state space. An event is a set of outcomes from a given sample
space. For example, if X represents the face of a die that is rolled, so $\mathcal{X} = \{1, 2, \ldots, 6\}$, the event
of “seeing a 1” is denoted X = 1, the event of “seeing an odd number” is denoted $X \in \{1, 3, 5\}$, the
event of “seeing a number between 1 and 3” is denoted 1 ≤ X ≤ 3, etc.


2.2.1 Discrete random variables 


If the sample space $\mathcal{X}$ is finite or countably infinite, then X is called a discrete random variable.
In this case, we denote the probability of the event that X has value x by Pr(X = x). We define the
probability mass function or pmf as a function which computes the probability of events which
correspond to setting the rv to each possible value:


$$p(x) \triangleq \Pr(X = x) \tag{2.8}$$


The pmf satisfies the properties $0 \le p(x) \le 1$ and $\sum_{x \in \mathcal{X}} p(x) = 1$.


Figure 2.1: Some discrete distributions on the state space $\mathcal{X} = \{1,2,3,4\}$. (a) A uniform distribution with
p(x = k) = 1/4. (b) A degenerate distribution (delta function) that puts all its mass on x = 1. Generated by
discrete_prob_dist_plot.ipynb.


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license

If X has a finite number of values, say K, the pmf can be represented as a list of K numbers, which 
we can plot as a histogram. For example, Figure 2.1 shows two pmf’s defined on $\mathcal{X} = \{1,2,3,4\}$.
On the left we have a uniform distribution, p(x) = 1/4, and on the right, we have a degenerate
distribution, $p(x) = \mathbb{I}(x = 1)$, where $\mathbb{I}()$ is the binary indicator function. Thus the distribution in
Figure 2.1(b) represents the fact that X is always equal to the value 1. (Thus we see that random 
variables can also be constant.) 
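As a small illustration, the two pmf’s described above can be stored as lists of K = 4 numbers and checked against the pmf properties. The helper `is_valid_pmf` and the random seed are arbitrary choices for this sketch.

```python
import random

states = [1, 2, 3, 4]
uniform = [0.25, 0.25, 0.25, 0.25]                     # p(x = k) = 1/4
degenerate = [1.0 if x == 1 else 0.0 for x in states]  # p(x) = I(x = 1)

def is_valid_pmf(p, tol=1e-12):
    """Check 0 <= p(x) <= 1 for all x, and that the entries sum to one."""
    return all(0.0 <= pk <= 1.0 for pk in p) and abs(sum(p) - 1.0) < tol

rng = random.Random(0)   # fixed seed, chosen arbitrarily
samples = rng.choices(states, weights=degenerate, k=5)
```

Sampling from the degenerate pmf always returns the value 1, which is the sense in which a "random" variable can be constant.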


2.2.2 Continuous random variables 


If $X \in \mathbb{R}$ is a real-valued quantity, it is called a continuous random variable. In this case, we can
no longer create a finite (or countable) set of distinct possible values it can take on. However, there 
are a countable number of intervals which we can partition the real line into. If we associate events 
with X being in each one of these intervals, we can use the methods discussed above for discrete 
random variables. Informally speaking, we can represent the probability of X taking on a specific 
real value by allowing the size of the intervals to shrink to zero, as we show below. 


2.2.2.1 Cumulative distribution function (cdf) 


Define the events $A = (X \le a)$, $B = (X \le b)$ and $C = (a < X \le b)$, where a < b. We have that
$B = A \vee C$, and since A and C are mutually exclusive, the sum rule gives


$$\Pr(B) = \Pr(A) + \Pr(C) \tag{2.9}$$

and hence the probability of being in interval C is given by

$$\Pr(C) = \Pr(B) - \Pr(A) \tag{2.10}$$

Figure 2.2: (a) Plot of the cdf for the standard normal, $\mathcal{N}(0, 1)$. Generated by gauss_plot.ipynb. (b)
Corresponding pdf. The shaded regions each contain α/2 of the probability mass. Therefore the nonshaded
region contains 1 − α of the probability mass. The leftmost cutoff point is $\Phi^{-1}(\alpha/2)$, where Φ is the cdf
of the Gaussian. By symmetry, the rightmost cutoff point is $\Phi^{-1}(1 - \alpha/2) = -\Phi^{-1}(\alpha/2)$. Generated by
quantile_plot.ipynb.


In general, we define the cumulative distribution function or cdf of the rv X as follows:

$$P(x) \triangleq \Pr(X \le x) \tag{2.11}$$

(Note that we use a capital P to represent the cdf.) Using this, we can compute the probability of
being in any interval as follows:

$$\Pr(a < X \le b) = P(b) - P(a) \tag{2.12}$$


Cdf’s are monotonically non-decreasing functions. See Figure 2.2a for an example, where we
illustrate the cdf of a standard normal distribution, $\mathcal{N}(x|0, 1)$; see Section 2.6 for details.


2.2.2.2 Probability density function (pdf) 


We define the probability density function or pdf as the derivative of the cdf:


$$p(x) \triangleq \frac{d}{dx} P(x) \tag{2.13}$$


(Note that this derivative does not always exist, in which case the pdf is not defined.) See Figure 2.2b
for an example, where we illustrate the pdf of a univariate Gaussian (see Section 2.6 for details).

Given a pdf, we can compute the probability of a continuous variable being in a finite interval as
follows:


$$\Pr(a < X \le b) = \int_a^b p(x)\,dx = P(b) - P(a) \tag{2.14}$$


As the size of the interval gets smaller, we can write

$$\Pr(x < X \le x + dx) \approx p(x)\,dx \tag{2.15}$$


Intuitively, this says the probability of X being in a small interval around x is the density at x times
the width of the interval.
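The relationships between the cdf and pdf can be checked numerically for the standard normal. The sketch below uses `statistics.NormalDist` from the Python standard library; the evaluation point x = 0.5 and the step sizes are arbitrary choices.

```python
from statistics import NormalDist

std_normal = NormalDist(mu=0.0, sigma=1.0)

# Probability of an interval as a difference of cdf values, Eq. (2.12).
a, b = -1.0, 1.0
p_interval = std_normal.cdf(b) - std_normal.cdf(a)

# A central finite difference of the cdf approximates the pdf, Eq. (2.13).
x, h = 0.5, 1e-5
pdf_numeric = (std_normal.cdf(x + h) - std_normal.cdf(x - h)) / (2 * h)

# Small-interval approximation, Eq. (2.15): Pr(x < X <= x + dx) ~ p(x) dx.
dx = 1e-3
p_small = std_normal.cdf(x + dx) - std_normal.cdf(x)
p_small_approx = std_normal.pdf(x) * dx
```

The approximation in Equation (2.15) improves as dx shrinks, since the neglected term is of order dx squared.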

2.2.2.3 Quantiles


If the cdf P is strictly monotonically increasing, it has an inverse, called the inverse cdf, or percent 
point function (ppf), or quantile function. 

If P is the cdf of X, then $P^{-1}(q)$ is the value $x_q$ such that $\Pr(X \le x_q) = q$; this is called the q’th
quantile of P. The value $P^{-1}(0.5)$ is the median of the distribution, with half of the probability
mass on the left, and half on the right. The values $P^{-1}(0.25)$ and $P^{-1}(0.75)$ are the lower and upper
quartiles.

For example, let Φ be the cdf of the Gaussian distribution $\mathcal{N}(0, 1)$, and $\Phi^{-1}$ be the inverse cdf.
Then points to the left of $\Phi^{-1}(\alpha/2)$ contain α/2 of the probability mass, as illustrated in Figure 2.2b.
By symmetry, points to the right of $\Phi^{-1}(1 - \alpha/2)$ also contain α/2 of the mass. Hence the central
interval $(\Phi^{-1}(\alpha/2), \Phi^{-1}(1 - \alpha/2))$ contains 1 − α of the mass. If we set α = 0.05, the central 95%
interval is covered by the range


$$(\Phi^{-1}(0.025), \Phi^{-1}(0.975)) = (-1.96, 1.96) \tag{2.16}$$


If the distribution is $\mathcal{N}(\mu, \sigma^2)$, then the 95% interval becomes $(\mu - 1.96\sigma, \mu + 1.96\sigma)$. This is often
approximated by writing $\mu \pm 2\sigma$.
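The central interval in Equation (2.16) can be reproduced with the inverse cdf provided by `statistics.NormalDist`. The function name `central_interval` and the non-standard Gaussian with μ = 10, σ = 2 are hypothetical, chosen only to illustrate the scaling.

```python
from statistics import NormalDist

def central_interval(dist, alpha):
    """Interval containing 1 - alpha of the mass, with alpha/2 in each tail."""
    return dist.inv_cdf(alpha / 2), dist.inv_cdf(1 - alpha / 2)

std_normal = NormalDist(0.0, 1.0)
lo, hi = central_interval(std_normal, alpha=0.05)
median = std_normal.inv_cdf(0.5)

# Hypothetical non-standard Gaussian with mu = 10, sigma = 2.
lo2, hi2 = central_interval(NormalDist(10.0, 2.0), alpha=0.05)
```

For the second distribution, the endpoints are approximately μ ± 1.96σ, in agreement with the statement in the text.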


2.2.3 Sets of related random variables 


In this section, we discuss distributions over sets of related random variables. 

Suppose, to start, that we have two random variables, X and Y. We can define the joint 
distribution of two random variables using p(x, y) = p(X = x,Y = y) for all possible values of 
X and Y. If both variables have finite cardinality, we can represent the joint distribution as a 2d 
table, all of whose entries sum to one. For example, consider the following joint distribution over two binary
variables:


p(X, Y) | Y = 0   Y = 1
X = 0   |  0.2     0.3
X = 1   |  0.3     0.2

If two variables are independent, we can represent the joint as the product of the two marginals. If
both variables have finite cardinality, we can factorize the 2d joint table into a product of two 1d
vectors, as shown in Figure 2.3.

Given a joint distribution, we define the marginal distribution of an rv as follows:


$$p(X = x) = \sum_{y} p(X = x, Y = y) \tag{2.17}$$


where we are summing over all possible states of Y. This is sometimes called the sum rule or the
rule of total probability. We define p(Y = y) similarly. For example, from the above 2d table, we
see p(X = 0) = 0.2 + 0.3 = 0.5 and p(Y = 0) = 0.2 + 0.3 = 0.5. (The term “marginal” comes from
the accounting practice of writing the sums of rows and columns on the side, or margin, of a table.)

We define the conditional distribution of an rv using


$$p(Y = y|X = x) = \frac{p(X = x, Y = y)}{p(X = x)} \tag{2.18}$$

We can rearrange this equation to get

$$p(x, y) = p(x) p(y|x) \tag{2.19}$$

This is called the product rule.

By extending the product rule to D variables, we get the chain rule of probability:


$$p(x_{1:D}) = p(x_1) p(x_2|x_1) p(x_3|x_1, x_2) p(x_4|x_1, x_2, x_3) \cdots p(x_D|x_{1:D-1}) \tag{2.20}$$

This provides a way to create a high dimensional joint distribution from a set of conditional
distributions. We discuss this in more detail in Section 3.6.


Figure 2.3: Computing p(x, y) = p(x)p(y), where $X \perp Y$. Here X and Y are discrete random variables; X
has 6 possible states (values) and Y has 5 possible states. A general joint distribution on two such variables
would require (6 × 5) − 1 = 29 parameters to define it (we subtract 1 because of the sum-to-one constraint).
By assuming (unconditional) independence, we only need (6 − 1) + (5 − 1) = 9 parameters to define p(x, y).
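The sum rule, conditional distribution and product rule can be verified on the 2 × 2 joint table given above. This sketch stores the table as a nested list indexed as `joint[x][y]`; the variable names are introduced only for this example.

```python
# Joint table p(X = x, Y = y) from the text, indexed as joint[x][y].
joint = [[0.2, 0.3],
         [0.3, 0.2]]

# Sum rule: marginalize out the other variable.
p_x = [sum(row) for row in joint]
p_y = [sum(joint[x][y] for x in range(2)) for y in range(2)]

# Conditional p(Y = y | X = x), Eq. (2.18).
p_y_given_x = [[joint[x][y] / p_x[x] for y in range(2)] for x in range(2)]

# Product rule, Eq. (2.19): p(x, y) = p(x) p(y | x).
rebuilt = [[p_x[x] * p_y_given_x[x][y] for y in range(2)] for x in range(2)]

# X and Y are independent only if the joint equals the product of the marginals.
independent = all(abs(joint[x][y] - p_x[x] * p_y[y]) < 1e-12
                  for x in range(2) for y in range(2))
```

For this particular table both marginals are uniform, yet the joint does not factorize (0.2 is not 0.5 × 0.5), so X and Y are dependent.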

2.2.4 Independence and conditional independence 


We say X and Y are unconditionally independent or marginally independent, denoted $X \perp Y$,
if we can represent the joint as the product of the two marginals (see Figure 2.3), i.e.,


$$X \perp Y \iff p(X, Y) = p(X) p(Y) \tag{2.21}$$


In general, we say a set of variables $X_1, \ldots, X_n$ is (mutually) independent if the joint can be written
as a product of marginals for all subsets $\{X_1, \ldots, X_m\} \subseteq \{X_1, \ldots, X_n\}$, i.e.,


$$p(X_1, \ldots, X_m) = \prod_{i=1}^{m} p(X_i) \tag{2.22}$$


For example, we say $X_1, X_2, X_3$ are mutually independent if the following conditions hold:
$p(X_1, X_2, X_3) = p(X_1) p(X_2) p(X_3)$, $p(X_1, X_2) = p(X_1) p(X_2)$, $p(X_2, X_3) = p(X_2) p(X_3)$, and
$p(X_1, X_3) = p(X_1) p(X_3)$.²

Unfortunately, unconditional independence is rare, because most variables can influence most other
variables. However, usually this influence is mediated via other variables rather than being direct.
We therefore say X and Y are conditionally independent (CI) given Z iff the conditional joint
can be written as a product of conditional marginals:


$$X \perp Y \mid Z \iff p(X, Y|Z) = p(X|Z) p(Y|Z) \tag{2.23}$$


We can write this assumption as a graph X - Z - Y, which captures the intuition that all the
dependencies between X and Y are mediated via Z. By using larger graphs, we can define complex
joint distributions; these are known as graphical models, and are discussed in Section 3.6.


2. For further discussion, see https://github.com/probml/pml-book/issues/353#issuecomment-1120327442.

2.2.5 Moments of a distribution 

In this section, we describe various summary statistics that can be derived from a probability 
distribution (either a pdf or pmf). 

2.2.5.1 Mean of a distribution 


The most familiar property of a distribution is its mean, or expected value, often denoted by μ.
For continuous rv’s, the mean is defined as follows:


$$\mathbb{E}[X] \triangleq \int_{\mathcal{X}} x\, p(x)\,dx \tag{2.24}$$


If the integral is not finite, the mean is not defined; we will see some examples of this later.
For discrete rv’s, the mean is defined as follows:


$$\mathbb{E}[X] \triangleq \sum_{x \in \mathcal{X}} x\, p(x) \tag{2.25}$$


However, this is only meaningful if the values of x are ordered in some way (e.g., if they represent
integer counts).
Since the mean is a linear operator, we have


$$\mathbb{E}[aX + b] = a \mathbb{E}[X] + b \tag{2.26}$$


This is called the linearity of expectation.
For a set of n random variables, one can show that the expectation of their sum is as follows:


$$\mathbb{E}\left[\sum_{i=1}^{n} X_i\right] = \sum_{i=1}^{n} \mathbb{E}[X_i] \tag{2.27}$$


If they are independent, the expectation of their product is given by


$$\mathbb{E}\left[\prod_{i=1}^{n} X_i\right] = \prod_{i=1}^{n} \mathbb{E}[X_i] \tag{2.28}$$


2.2.5.2 Variance of a distribution


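The variance is a measure of the “spread” of a distribution, often denoted by $\sigma^2$. This is defined as
follows:


$$\mathbb{V}[X] \triangleq \mathbb{E}[(X - \mu)^2] = \int (x - \mu)^2 p(x)\,dx \tag{2.29}$$

$$= \int x^2 p(x)\,dx + \mu^2 \int p(x)\,dx - 2\mu \int x\, p(x)\,dx = \mathbb{E}[X^2] - \mu^2 \tag{2.30}$$

from which we derive the useful result


$$\mathbb{E}[X^2] = \sigma^2 + \mu^2 \tag{2.31}$$


The standard deviation is defined as

$$\mathrm{std}[X] \triangleq \sqrt{\mathbb{V}[X]} = \sigma \tag{2.32}$$


This is useful since it has the same units as X itself.
The variance of a shifted and scaled version of a random variable is given by


$$\mathbb{V}[aX + b] = a^2 \mathbb{V}[X] \tag{2.33}$$


If we have a set of n independent random variables, the variance of their sum is given by the sum
of their variances:


$$\mathbb{V}\left[\sum_{i=1}^{n} X_i\right] = \sum_{i=1}^{n} \mathbb{V}[X_i] \tag{2.34}$$

The variance of their product can also be derived, as follows:

$$\mathbb{V}\left[\prod_{i=1}^{n} X_i\right] = \mathbb{E}\left[\left(\prod_i X_i\right)^2\right] - \left(\mathbb{E}\left[\prod_i X_i\right]\right)^2 \tag{2.35}$$

$$= \mathbb{E}\left[\prod_i X_i^2\right] - \left(\prod_i \mathbb{E}[X_i]\right)^2 \tag{2.36}$$

$$= \prod_i \mathbb{E}[X_i^2] - \prod_i (\mathbb{E}[X_i])^2 \tag{2.37}$$

$$= \prod_i \left(\mathbb{V}[X_i] + (\mathbb{E}[X_i])^2\right) - \prod_i (\mathbb{E}[X_i])^2 \tag{2.38}$$

$$= \prod_i (\sigma_i^2 + \mu_i^2) - \prod_i \mu_i^2 \tag{2.39}$$
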

2.2.5.3 Mode of a distribution


The mode of a distribution is the value with the highest probability mass or probability density:

$$x^* = \operatorname*{argmax}_{x} p(x) \tag{2.40}$$

If the distribution is multimodal, this may not be unique, as illustrated in Figure 2.4. Furthermore,
even if there is a unique mode, this point may not be a good summary of the distribution.


2.2.5.4 Conditional moments


When we have two or more dependent random variables, we can compute the moments of one given 
knowledge of the other. For example, the law of iterated expectations, also called the law of 
total expectation, tells us that 


$$\mathbb{E}[X] = \mathbb{E}_Y[\mathbb{E}[X|Y]] \tag{2.41}$$


Figure 2.4: Illustration of a mixture of two 1d Gaussians, $p(x) = 0.5\mathcal{N}(x|0, 0.5) + 0.5\mathcal{N}(x|2, 0.5)$. Generated
by bimodal_dist_plot.ipynb.


To prove this, let us suppose, for simplicity, that X and Y are both discrete rv’s. Then we have


$$\mathbb{E}_Y[\mathbb{E}[X|Y]] = \mathbb{E}_Y\left[\sum_x x\, p(X = x|Y)\right] \tag{2.42}$$

$$= \sum_y \left[\sum_x x\, p(X = x|Y = y)\right] p(Y = y) = \sum_{x,y} x\, p(X = x, Y = y) = \mathbb{E}[X] \tag{2.43}$$


To give a more intuitive explanation, consider the following simple example.³ Let X be the
lifetime duration of a lightbulb, and let Y be the factory the lightbulb was produced in. Suppose
$\mathbb{E}[X|Y = 1] = 5000$ and $\mathbb{E}[X|Y = 2] = 4000$, indicating that factory 1 produces longer lasting bulbs.
Suppose factory 1 supplies 60% of the lightbulbs, so p(Y = 1) = 0.6 and p(Y = 2) = 0.4. Then the
expected duration of a random lightbulb is given by


$$\mathbb{E}[X] = \mathbb{E}[X|Y = 1]\, p(Y = 1) + \mathbb{E}[X|Y = 2]\, p(Y = 2) = 5000 \times 0.6 + 4000 \times 0.4 = 4600 \tag{2.44}$$


There is a similar formula for the variance. In particular, the law of total variance, also called 
the conditional variance formula, tells us that 


$$\mathbb{V}[X] = \mathbb{E}_Y[\mathbb{V}[X|Y]] + \mathbb{V}_Y[\mathbb{E}[X|Y]] \tag{2.45}$$


To see this, let us define the conditional moments, $\mu_{X|Y} = \mathbb{E}[X|Y]$, $s_{X|Y} = \mathbb{E}[X^2|Y]$, and
$\sigma^2_{X|Y} = \mathbb{V}[X|Y] = s_{X|Y} - \mu^2_{X|Y}$, which are functions of Y (and therefore are random quantities).
Then we have


$$\mathbb{V}[X] = \mathbb{E}[X^2] - (\mathbb{E}[X])^2 = \mathbb{E}_Y[s_{X|Y}] - (\mathbb{E}_Y[\mu_{X|Y}])^2 \tag{2.46}$$

$$= \mathbb{E}_Y[\sigma^2_{X|Y}] + \mathbb{E}_Y[\mu^2_{X|Y}] - (\mathbb{E}_Y[\mu_{X|Y}])^2 \tag{2.47}$$

$$= \mathbb{E}_Y[\mathbb{V}[X|Y]] + \mathbb{V}_Y[\mu_{X|Y}] \tag{2.48}$$


To get some intuition for these formulas, consider a mixture of K univariate Gaussians. Let
Y be the hidden indicator variable that specifies which mixture component we are using, and let
$X \sim \sum_{y=1}^{K} \pi_y \mathcal{N}(X|\mu_y, \sigma_y)$. In Figure 2.4, we have $\pi_1 = \pi_2 = 0.5$, $\mu_1 = 0$, $\mu_2 = 2$, $\sigma_1 = \sigma_2 = 0.5$.
Thus


$$\mathbb{E}[\mathbb{V}[X|Y]] = \pi_1 \sigma_1^2 + \pi_2 \sigma_2^2 = 0.25 \tag{2.49}$$

$$\mathbb{V}[\mathbb{E}[X|Y]] = \pi_1 (\mu_1 - \bar{\mu})^2 + \pi_2 (\mu_2 - \bar{\mu})^2 = 0.5(0 - 1)^2 + 0.5(2 - 1)^2 = 0.5 + 0.5 = 1 \tag{2.50}$$


So we get the intuitive result that the variance of X is dominated by which centroid it is drawn from
(i.e., difference in the means), rather than the local variance around each centroid.


3. This example is from https://en.wikipedia.org/wiki/Law_of_total_expectation, but with modified notation.


Figure 2.5: Illustration of Anscombe’s quartet. (a) Dataset I. (b) Dataset II. (c) Dataset III. (d) Dataset IV.
All of these datasets have the same low order summary statistics. Generated by anscombes_quartet.ipynb.
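The decomposition in Equation (2.45) can be checked numerically for the two-component mixture discussed above. The sketch below takes σ₁ = σ₂ = 0.5 as standard deviations, as in Equation (2.49), and cross-checks the result against the direct formula V[X] = E[X²] − (E[X])²; the variable names are introduced only for this example.

```python
# Mixture from the text: pi = (0.5, 0.5), mu = (0, 2), sigma = (0.5, 0.5).
pi = [0.5, 0.5]
mu = [0.0, 2.0]
sigma = [0.5, 0.5]

mean_x = sum(p * m for p, m in zip(pi, mu))                    # E_Y[E[X|Y]]
within = sum(p * s ** 2 for p, s in zip(pi, sigma))            # E_Y[V[X|Y]]
between = sum(p * (m - mean_x) ** 2 for p, m in zip(pi, mu))   # V_Y[E[X|Y]]
total_var = within + between

# Cross-check via V[X] = E[X^2] - (E[X])^2, using E[X^2|Y] = sigma^2 + mu^2.
second_moment = sum(p * (s ** 2 + m ** 2) for p, s, m in zip(pi, sigma, mu))
direct_var = second_moment - mean_x ** 2
```

Here the between-component term (1.0) is four times the within-component term (0.25), which is the sense in which the variance is dominated by the separation of the means.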


2.2.6 Limitations of summary statistics * 


Although it is common to summarize a probability distribution (or points sampled from a distribution) 
using simple statistics such as the mean and variance, this can lose a lot of information. A striking 
example of this is known as Anscombe’s quartet [Ans73], which is illustrated in Figure 2.5. This 
shows 4 different datasets of (x, y) pairs, all of which have identical mean, variance and correlation 
coefficient ρ (defined in Section 3.1.2): E [x] = 9, V [x] = 11, E [y] = 7.50, V [y] = 4.12, and ρ = 0.816.⁴
However, the joint distributions p(x, y) from which these points were sampled are clearly very different.
Anscombe invented these datasets, each consisting of 11 data points, to counter the impression among
statisticians that numerical summaries are superior to data visualization [Ans73]. 
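
To make the point concrete, the sketch below computes the summary statistics for the first two datasets of the quartet. The numerical values are transcribed from Anscombe's published table rather than taken from this text, so treat them as an assumption to be checked against [Ans73]; the helper name `summary` is ours.

```python
import numpy as np

# Datasets I and II of Anscombe's quartet share the same x values.
x = np.array([10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5], dtype=float)
y1 = np.array([8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68])
y2 = np.array([9.14, 8.14, 8.74, 8.77, 9.26, 8.10, 6.13, 3.10, 9.13, 7.26, 4.74])

def summary(x, y, ddof=1):
    """Mean, variance and correlation; ddof=1 is unbiased, ddof=0 is the MLE."""
    return {
        "mean_x": x.mean(),
        "var_x": x.var(ddof=ddof),
        "mean_y": y.mean(),
        "var_y": y.var(ddof=ddof),
        "corr": np.corrcoef(x, y)[0, 1],
    }

stats1 = summary(x, y1)
stats2 = summary(x, y2)
mle1 = summary(x, y1, ddof=0)

for name, stats in [("I", stats1), ("II", stats2), ("I (MLE)", mle1)]:
    print(name, {k: round(float(v), 3) for k, v in stats.items()})
```

Dataset I is roughly linear with noise and dataset II is a smooth curve, yet the two agree on all five statistics to about two decimal places. The `ddof` argument switches between the unbiased variance estimate and the maximum likelihood estimate.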

An even more striking example of this phenomenon is shown in Figure 2.6. This consists of a 
dataset that looks like a dinosaur⁵, plus 11 other datasets, all of which have identical low order
statistics. This collection of datasets is called the Datasaurus Dozen [MF17]. The exact values of
the (x, y) points are available online.⁶ They were computed using simulated annealing, a derivative-free
optimization method which we discuss in the sequel to this book, [Mur23]. (The objective


4. The maximum likelihood estimate for the variance in Equation (4.36) differs from the unbiased estimate in
Equation (4.38). For the former, we have V [x] = 10.00, V [y] = 3.75, for the latter, we have V [x] = 11.00, V [y] = 4.12.
5. This dataset was created by Alberto Cairo, and is available at http://www.thefunctionalart.com/2016/08/download-datasaurus-never-trust-summary.html

6. https://www.autodesk.com/research/publications/same-stats-different-graphs. There are actually 13 
datasets in total, including the dinosaur. We omitted the “away” dataset for visual clarity. 
Figure 2.6: Illustration of the Datasaurus Dozen. All of these datasets have the same low order summary
statistics. Adapted from Figure 1 of [MF17]. Generated by datasaurus_dozen.ipynb.


function being optimized measures deviation from the target summary statistics of the original 
dinosaur, plus distance from a particular target shape.) 

The same simulated annealing approach can be applied to 1d datasets, as shown in Figure 2.7. We 
see that all the datasets are quite different, but they all have the same median and inter-quartile 
range as shown by the central shaded part of the box plots in the middle. A better visualization 
is known as a violin plot, shown on the right. This shows (two copies of) the 1d kernel density 
estimate (Section 16.3) of the distribution on the vertical axis, in addition to the median and IQR 
markers. This visualization is better able to distinguish differences in the distributions. However, the 
technique is limited to 1d data. 


2.3 Bayes’ rule 


Bayes’s theorem is to the theory of probability what Pythagoras’s theorem is to geometry. 
— Sir Harold Jeffreys, 1973 [Jef73]. 


In this section, we discuss the basics of Bayesian inference. According to the Merriam-Webster
dictionary, the term “inference” means “the act of passing from sample data to generalizations, usually
with calculated degrees of certainty”. The term “Bayesian” is used to refer to inference methods
Figure 2.7: Illustration of 7 different datasets (left), the corresponding box plots (middle) and
violin box plots (right). From Figure 8 of https://www.autodesk.com/research/publications/same-stats-different-graphs.
Used with kind permission of Justin Matejka.


that represent “degrees of certainty” using probability theory, and which leverage Bayes’ rule⁷ to
update the degree of certainty given data. 

Bayes’ rule itself is very simple: it is just a formula for computing the probability distribution over 
possible values of an unknown (or hidden) quantity H given some observed data Y = y: 


$$p(H = h|Y = y) = \frac{p(H = h)\, p(Y = y|H = h)}{p(Y = y)} \tag{2.51}$$

This follows automatically from the identity 

$$p(h|y)\,p(y) = p(h)\,p(y|h) = p(h, y) \tag{2.52}$$


which itself follows from the product rule of probability. 

In Equation (2.51), the term p(H) represents what we know about possible values of H before 
we see any data; this is called the prior distribution. (If H has K possible values, then p(H) is 
a vector of K probabilities, that sum to 1.) The term p(Y|H = h) represents the distribution over 
the possible outcomes Y we expect to see if H = h; this is called the observation distribution. 
When we evaluate this at a point corresponding to the actual observations, y, we get the function 
p(Y = y|H = h), which is called the likelihood. (Note that this is a function of h, since y is 
fixed, but it is not a probability distribution, since it does not sum to one.) Multiplying the prior 
distribution p(H = h) by the likelihood function p(Y = y|H = h) for each h gives the unnormalized 
joint distribution p(H = h, Y = y). We can convert this into a normalized distribution by dividing 
by p(Y = y), which is known as the marginal likelihood, since it is computed by marginalizing 
over the unknown H: 


$$p(Y = y) = \sum_{h' \in \mathcal{H}} p(H = h')\, p(Y = y|H = h') = \sum_{h' \in \mathcal{H}} p(H = h', Y = y) \tag{2.53}$$


7. Thomas Bayes (1702-1761) was an English mathematician and Presbyterian minister. For a discussion of whether 
to spell this as Bayes rule, Bayes’ rule or Bayes’s rule, see https://bit.ly/2kDtLuK.


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 


46 Chapter 2. Probability: Univariate Models 


                      Observation
                 0                        1
Truth   0   TNR=Specificity=0.975    FPR=1-TNR=0.025
        1   FNR=1-TPR=0.125          TPR=Sensitivity=0.875


Table 2.1: Likelihood function p(Y |H) for a binary observation Y given two possible hidden states H. Each 
row sums to one. Abbreviations: TNR is true negative rate, TPR is true positive rate, FNR is false negative 
rate, FPR is false positive rate. 


Normalizing the joint distribution by computing p(H = h, Y = y)/p(Y = y) for each h gives the 
posterior distribution p(H = h|Y = y); this represents our new belief state about the possible 
values of H. 

We can summarize Bayes rule in words as follows: 


$$\text{posterior} \propto \text{prior} \times \text{likelihood} \tag{2.54}$$


Here we use the symbol ∝ to denote “proportional to”, since we are ignoring the denominator, which is
just a constant, independent of H. Using Bayes rule to update a distribution over unknown values of 
some quantity of interest, given relevant observed data, is called Bayesian inference, or posterior 
inference. It can also just be called probabilistic inference. 

Below we give some simple examples of Bayesian inference in action. We will see many more 
interesting examples later in this book. 


2.3.1 Example: Testing for COVID-19 


Suppose you think you may have contracted COVID-19, which is an infectious disease caused by 
the SARS-CoV-2 virus. You decide to take a diagnostic test, and you want to use its result to 
determine if you are infected or not. 

Let H = 1 be the event that you are infected, and H = 0 be the event you are not infected. Let 
Y = 1 if the test is positive, and Y = 0 if the test is negative. We want to compute p(H = h|Y = y), 
for h ∈ {0, 1}, where y is the observed test outcome. (We will write the distribution of values,
[p(H = 0|Y = y), p(H = 1|Y = y)] as p(H|y), for brevity.) We can think of this as a form of binary
classification, where H is the unknown class label, and y is the feature vector. 

First we must specify the likelihood. This quantity obviously depends on how reliable the 
test is. There are two key parameters. The sensitivity (aka true positive rate) is defined as 
p(Y = 1|H = 1), i.e., the probability of a positive test given that the truth is positive. The false 
negative rate is defined as one minus the sensitivity. The specificity (aka true negative rate) 
is defined as p(Y = 0|H = 0), i.e., the probability of a negative test given that the truth is negative. 
The false positive rate is defined as one minus the specificity. We summarize all these quantities 
in Table 2.1. (See Section 5.1.3.1 for more details.) Following https://nyti.ms/31MTZgV, we set
the sensitivity to 87.5% and the specificity to 97.5%. 

Next we must specify the prior. The quantity p(H = 1) represents the prevalence of the 
disease in the area in which you live. We set this to p(H = 1) = 0.1 (i.e., 10%), which was the 
prevalence in New York City in Spring 2020. (This example was chosen to match the numbers in 
https://nyti.ms/31MTZegV.) 
Now suppose you test positive. We have

\begin{align}
p(H = 1|Y = 1) &= \frac{p(Y = 1|H = 1)\,p(H = 1)}{p(Y = 1|H = 1)\,p(H = 1) + p(Y = 1|H = 0)\,p(H = 0)} \tag{2.55} \\
&= \frac{\text{TPR} \times \text{prior}}{\text{TPR} \times \text{prior} + \text{FPR} \times (1 - \text{prior})} \tag{2.56} \\
&= \frac{0.875 \times 0.1}{0.875 \times 0.1 + 0.025 \times 0.9} = 0.795 \tag{2.57}
\end{align}

So there is a 79.5% chance you are infected. 
Now suppose you test negative. The probability you are infected is given by 
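
\begin{align}
p(H = 1|Y = 0) &= \frac{p(Y = 0|H = 1)\,p(H = 1)}{p(Y = 0|H = 1)\,p(H = 1) + p(Y = 0|H = 0)\,p(H = 0)} \tag{2.58} \\
&= \frac{\text{FNR} \times \text{prior}}{\text{FNR} \times \text{prior} + \text{TNR} \times (1 - \text{prior})} \tag{2.59} \\
&= \frac{0.125 \times 0.1}{0.125 \times 0.1 + 0.975 \times 0.9} = 0.014 \tag{2.60}
\end{align}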


So there is just a 1.4% chance you are infected. 

Nowadays COVID-19 prevalence is much lower. Suppose we repeat these calculations using a base 
rate of 1%; now the posteriors reduce to 26% and 0.13% respectively. 

The fact that you only have a 26% chance of being infected with COVID-19, even after a positive 
test, is very counter-intuitive. The reason is that a single positive test is more likely to be a false 
positive than due to the disease, since the disease is rare. To see this, suppose we have a population 
of 100,000 people, of whom 1000 are infected. Of those who are infected, 875 = 0.875 × 1000 test
positive, and of those who are uninfected, 2475 = 0.025 × 99,000 test positive. Thus the total number
of positives is 3350 = 875 + 2475, so the posterior probability of being infected given a positive test 
is 875/3350 = 0.26. 

Of course, the above calculations assume we know the sensitivity and specificity of the test. See 
[GC20] for how to apply Bayes rule for diagnostic testing when there is uncertainty about these 
parameters. 
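
The calculations in this example can be reproduced with a few lines of code. The sketch below is a minimal illustration, assuming the sensitivity and specificity are known exactly; the function name `posterior_infected` and its signature are our own choices.

```python
def posterior_infected(prior, sensitivity, specificity, test_positive):
    """Return p(H=1 | Y=y) for a binary diagnostic test via Bayes' rule."""
    if test_positive:
        lik_infected = sensitivity          # TPR = p(Y=1 | H=1)
        lik_healthy = 1.0 - specificity     # FPR = p(Y=1 | H=0)
    else:
        lik_infected = 1.0 - sensitivity    # FNR = p(Y=0 | H=1)
        lik_healthy = specificity           # TNR = p(Y=0 | H=0)
    joint_infected = lik_infected * prior            # p(H=1, Y=y)
    joint_healthy = lik_healthy * (1.0 - prior)      # p(H=0, Y=y)
    return joint_infected / (joint_infected + joint_healthy)

# Prevalence of 10% and 1%, with the test parameters used in the text.
for prior in (0.1, 0.01):
    pos = posterior_infected(prior, 0.875, 0.975, test_positive=True)
    neg = posterior_infected(prior, 0.875, 0.975, test_positive=False)
    print(f"prior={prior}: p(H=1|Y=1)={pos:.4f}, p(H=1|Y=0)={neg:.4f}")
```

Varying `prior` while holding the test parameters fixed shows how strongly the posterior after a positive test depends on the base rate.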


2.3.2 Example: The Monty Hall problem 


In this section, we consider a more “frivolous” application of Bayes rule. In particular, we apply it to 
the famous Monty Hall problem. 

Imagine a game show with the following rules: There are three doors, labeled 1, 2, 3. A single prize 
(e.g., a car) has been hidden behind one of them. You get to select one door. Then the gameshow 
host opens one of the other two doors (not the one you picked), in such a way as to not reveal the 
prize location. At this point, you will be given a fresh choice of door: you can either stick with your 
first choice, or you can switch to the other closed door. All the doors will then be opened and you 
will receive whatever is behind your final choice of door. 

For example, suppose you choose door 1, and the gameshow host opens door 3, revealing nothing 
behind the door, as promised. Should you (a) stick with door 1, or (b) switch to door 2, or (c) does 
it make no difference? 
Door 1   Door 2   Door 3   Switch   Stay
Car      -        -        Lose     Win
-        Car      -        Win      Lose
-        -        Car      Win      Lose


Table 2.2: 3 possible states for the Monty Hall game, showing that switching doors is two times better (on 
average) than staying with your original choice. Adapted from Table 6.1 of [PM18]. 


Intuitively, it seems it should make no difference, since your initial choice of door cannot influence 
the location of the prize. However, the fact that the host opened door 3 tells us something about the 
location of the prize, since he made his choice conditioned on the knowledge of the true location and 
on your choice. As we show below, you are in fact twice as likely to win the prize if you switch to 
door 2. 

To show this, we will use Bayes’ rule. Let $H_i$ denote the hypothesis that the prize is behind door i.
We make the following assumptions: the three hypotheses $H_1$, $H_2$ and $H_3$ are equiprobable a priori,
i.e.,

$$P(H_1) = P(H_2) = P(H_3) = \frac{1}{3} \tag{2.61}$$

The datum we receive, after choosing door 1, is either Y = 3 or Y = 2 (meaning door 3 or 2 is
opened, respectively). We assume that these two possible outcomes have the following probabilities.
If the prize is behind door 1, then the host selects at random between Y = 2 and Y = 3. Otherwise
the choice of the host is forced and the probabilities are 0 and 1.

$$\begin{array}{lll}
P(Y = 2|H_1) = \frac{1}{2} & P(Y = 2|H_2) = 0 & P(Y = 2|H_3) = 1 \\
P(Y = 3|H_1) = \frac{1}{2} & P(Y = 3|H_2) = 1 & P(Y = 3|H_3) = 0
\end{array} \tag{2.62}$$

Now, using Bayes’ rule, we evaluate the posterior probabilities of the hypotheses:

$$P(H_i|Y = 3) = \frac{P(Y = 3|H_i)\,P(H_i)}{P(Y = 3)} \tag{2.63}$$

$$P(H_1|Y = 3) = \frac{(1/2)(1/3)}{P(Y = 3)}, \quad P(H_2|Y = 3) = \frac{(1)(1/3)}{P(Y = 3)}, \quad P(H_3|Y = 3) = \frac{(0)(1/3)}{P(Y = 3)} \tag{2.64}$$

The denominator P(Y = 3) is $P(Y = 3) = \frac{1}{6} + \frac{1}{3} = \frac{1}{2}$. So

$$P(H_1|Y = 3) = \frac{1}{3}, \quad P(H_2|Y = 3) = \frac{2}{3}, \quad P(H_3|Y = 3) = 0 \tag{2.65}$$

So the contestant should switch to door 2 in order to have the biggest chance of getting the prize. 
See Table 2.2 for a worked example. 
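
The posterior can also be checked numerically. The sketch below first applies Bayes' rule with exact fractions, then simulates the game under the stated host behavior (the host never opens the contestant's door or the prize door, and picks at random when both remaining doors are empty). The function name `play`, the seed, and the number of trials are illustrative assumptions.

```python
from fractions import Fraction
import random

# Exact Bayes update, given that the contestant picked door 1 and Y = 3.
prior = [Fraction(1, 3)] * 3
lik_y3 = [Fraction(1, 2), Fraction(1), Fraction(0)]   # P(Y=3 | H_i), i = 1, 2, 3
joint = [p * l for p, l in zip(prior, lik_y3)]
evidence = sum(joint)                                  # P(Y = 3)
posterior = [j / evidence for j in joint]
print(evidence, posterior)

def play(switch, rng):
    """Play one round with initial pick = door 1; return True if the prize is won."""
    prize = rng.randint(1, 3)
    pick = 1
    # The host opens a door that is neither the contestant's pick nor the prize.
    opened = rng.choice([d for d in (1, 2, 3) if d != pick and d != prize])
    if switch:
        pick = next(d for d in (1, 2, 3) if d not in (pick, opened))
    return pick == prize

rng = random.Random(0)
n_trials = 100_000
switch_rate = sum(play(True, rng) for _ in range(n_trials)) / n_trials
stay_rate = sum(play(False, rng) for _ in range(n_trials)) / n_trials
print(switch_rate, stay_rate)
```

The simulated win rates should land close to 2/3 for switching and 1/3 for staying, matching the exact posterior.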

Many people find this outcome surprising. One way to make it more intuitive is to perform a 
thought experiment in which the game is played with a million doors. The rules are now that the 
contestant chooses one door, then the game show host opens 999,998 doors in such a way as not to 
reveal the prize, leaving the contestant’s selected door and one other door closed. The contestant 
may now stick or switch. Imagine the contestant confronted by a million doors, of which doors 1 and 
234,598 have not been opened, door 1 having been the contestant’s initial guess. Where do you think 
the prize is? 
Figure 2.8: Any planar line-drawing is geometrically consistent with infinitely many 3-D structures. From
Figure 11 of [SA93]. Used with kind permission of Pawan Sinha.


2.3.3 Inverse problems * 


Probability theory is concerned with predicting a distribution over outcomes y given knowledge (or 
assumptions) about the state of the world, h. By contrast, inverse probability is concerned with 
inferring the state of the world from observations of outcomes. We can think of this as inverting the 
h → y mapping.

For example, consider trying to infer a 3d shape h from a 2d image y, which is a classic problem 
in visual scene understanding. Unfortunately, this is a fundamentally ill-posed problem, as 
illustrated in Figure 2.8, since there are multiple possible hidden h’s consistent with the same observed 
y (see e.g., [Piz01]). Similarly, we can view natural language understanding as an ill-posed 
problem, in which the listener must infer the intention h from the (often ambiguous) words spoken 
by the speaker (see e.g., [Sab21]). 

To tackle such inverse problems, we can use Bayes’ rule to compute the posterior, p(h|y), which 
gives a distribution over possible states of the world. This requires specifying the forwards model, 
p(y|h), as well as a prior p(h), which can be used to rule out (or downweight) implausible world 
states. We discuss this topic in more detail in the sequel to this book, [Mur23]. 


2.4 Bernoulli and binomial distributions


Perhaps the simplest probability distribution is the Bernoulli distribution, which can be used to
model binary events, as we discuss below.


2.4.1 Definition 


Consider tossing a coin, where the probability of the event that it lands heads is given by 0 ≤ θ ≤ 1.
Let Y = 1 denote this event, and let Y = 0 denote the event that the coin lands tails. Thus we are
assuming that p(Y = 1) = θ and p(Y = 0) = 1 − θ. This is called the Bernoulli distribution, and
can be written as follows


$$Y \sim \mathrm{Ber}(\theta) \tag{2.66}$$
Figure 2.9: Illustration of the binomial distribution with N = 10 and (a) θ = 0.25 and (b) θ = 0.9. Generated
by binom_dist_plot.ipynb.


where the symbol ~ means “is sampled from” or “is distributed as”, and Ber refers to Bernoulli. The 
probability mass function (pmf) of this distribution is defined as follows: 


$$\mathrm{Ber}(y|\theta) = \begin{cases} 1 - \theta & \text{if } y = 0 \\ \theta & \text{if } y = 1 \end{cases} \tag{2.67}$$


(See Section 2.2.1 for details on pmf’s.) We can write this in a more concise manner as follows: 

$$\mathrm{Ber}(y|\theta) \triangleq \theta^{y} (1 - \theta)^{1-y} \tag{2.68}$$

The Bernoulli distribution is a special case of the binomial distribution. To explain this, suppose
we observe a set of N Bernoulli trials, denoted $y_n \sim \mathrm{Ber}(\cdot|\theta)$, for n = 1 : N. Concretely, think of
tossing a coin N times. Let us define s to be the total number of heads, $s \triangleq \sum_{n=1}^{N} \mathbb{I}(y_n = 1)$. The
distribution of s is given by the binomial distribution:


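$$\mathrm{Bin}(s|N, \theta) \triangleq \binom{N}{s} \theta^{s} (1 - \theta)^{N-s} \tag{2.69}$$

where

$$\binom{N}{k} \triangleq \frac{N!}{(N-k)!\,k!} \tag{2.70}$$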


is the number of ways to choose k items from N (this is known as the binomial coefficient, and is 
pronounced “N choose k”). See Figure 2.9 for some examples of the binomial distribution. If N = 1, 
the binomial distribution reduces to the Bernoulli distribution. 
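
The binomial pmf is simple to evaluate directly from its definition. The following sketch uses only the Python standard library; the function name `binom_pmf` is our own, and the parameters match panel (a) of Figure 2.9.

```python
from math import comb

def binom_pmf(s, N, theta):
    """Bin(s | N, theta): probability of s heads in N independent tosses."""
    return comb(N, s) * theta**s * (1 - theta) ** (N - s)

# The setting of Figure 2.9(a): N = 10, theta = 0.25.
N, theta = 10, 0.25
pmf = [binom_pmf(s, N, theta) for s in range(N + 1)]
print([round(p, 4) for p in pmf])
print("sum:", sum(pmf), "mode:", pmf.index(max(pmf)))

# With N = 1 we recover the Bernoulli pmf: [1 - theta, theta].
bernoulli = [binom_pmf(y, 1, theta) for y in (0, 1)]
print(bernoulli)
```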

2.4.2 Sigmoid (logistic) function 


When we want to predict a binary variable y ∈ {0, 1} given some inputs $\boldsymbol{x} \in \mathcal{X}$, we need to use a
conditional probability distribution of the form 


$$p(y|\boldsymbol{x}, \boldsymbol{\theta}) = \mathrm{Ber}(y|f(\boldsymbol{x}; \boldsymbol{\theta})) \tag{2.71}$$
Figure 2.10: (a) The sigmoid (logistic) function $\sigma(a) = (1 + e^{-a})^{-1}$. (b) The Heaviside function $\mathbb{I}(a > 0)$.
Generated by activation_fun_plot.ipynb.


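\begin{align}
\sigma(x) &\triangleq \frac{1}{1 + e^{-x}} = \frac{e^{x}}{1 + e^{x}} \tag{2.72} \\
\frac{d}{dx} \sigma(x) &= \sigma(x)(1 - \sigma(x)) \tag{2.73} \\
1 - \sigma(x) &= \sigma(-x) \tag{2.74} \\
\sigma^{-1}(p) &= \log\left(\frac{p}{1 - p}\right) \triangleq \mathrm{logit}(p) \tag{2.75} \\
\sigma_{+}(x) &\triangleq \log(1 + e^{x}) \triangleq \mathrm{softplus}(x) \tag{2.76} \\
\frac{d}{dx} \sigma_{+}(x) &= \sigma(x) \tag{2.77}
\end{align}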


Table 2.3: Some useful properties of the sigmoid (logistic) and related functions. Note that the logit function
is the inverse of the sigmoid function, and has a domain of [0, 1]. 


where f(x; θ) is some function that predicts the mean parameter of the output distribution. We will
consider many different kinds of function f in Part II–Part IV.

To avoid the requirement that 0 ≤ f(x; θ) ≤ 1, we can let f be an unconstrained function, and
use the following model: 


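$$p(y|\boldsymbol{x}, \boldsymbol{\theta}) = \mathrm{Ber}(y|\sigma(f(\boldsymbol{x}; \boldsymbol{\theta}))) \tag{2.78}$$

Here σ() is the sigmoid or logistic function, defined as follows:

$$\sigma(a) \triangleq \frac{1}{1 + e^{-a}} \tag{2.79}$$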


where a = f(x; θ). The term “sigmoid” means S-shaped: see Figure 2.10a for a plot. We see that it
Figure 2.11: Logistic regression applied to a 1-dimensional, 2-class version of the Iris dataset. Generated by
iris_logreg.ipynb. Adapted from Figure 4.23 of [Gér19].


maps the whole real line to [0,1], which is necessary for the output to be interpreted as a probability 
(and hence a valid value for the Bernoulli parameter θ). The sigmoid function can be thought of as a
“soft” version of the Heaviside step function, defined by


$$H(a) \triangleq \mathbb{I}(a > 0) \tag{2.80}$$


as shown in Figure 2.10b. 
Plugging the definition of the sigmoid function into Equation (2.78) we get 


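\begin{align}
p(y = 1|\boldsymbol{x}, \boldsymbol{\theta}) &= \frac{1}{1 + e^{-a}} = \frac{e^{a}}{1 + e^{a}} = \sigma(a) \tag{2.81} \\
p(y = 0|\boldsymbol{x}, \boldsymbol{\theta}) &= 1 - \frac{1}{1 + e^{-a}} = \frac{e^{-a}}{1 + e^{-a}} = \frac{1}{1 + e^{a}} = \sigma(-a) \tag{2.82}
\end{align}

The quantity a is equal to the log odds, $\log\left(\frac{p}{1-p}\right)$, where p = p(y = 1|x; θ). To see this, note that

$$\log\left(\frac{p}{1 - p}\right) = \log\left(\frac{e^{a}}{1 + e^{a}} \cdot \frac{1 + e^{a}}{1}\right) = \log(e^{a}) = a \tag{2.83}$$
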
The logistic function or sigmoid function maps the log-odds a to p: 

$$p = \mathrm{logistic}(a) = \sigma(a) \triangleq \frac{1}{1 + e^{-a}} = \frac{e^{a}}{1 + e^{a}} \tag{2.84}$$

The inverse of this is called the logit function, and maps p to the log-odds a: 

$$a = \mathrm{logit}(p) = \sigma^{-1}(p) \triangleq \log\left(\frac{p}{1 - p}\right) \tag{2.85}$$


See Table 2.3 for some useful properties of these functions. 
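
The identities in Table 2.3 are easy to verify numerically. The sketch below is one possible implementation, assuming NumPy is available. The helper names are ours, and the sigmoid is written in a form that avoids overflow for large |a|, which is an implementation choice rather than part of the definition.

```python
import numpy as np

def sigmoid(a):
    """sigma(a) = 1 / (1 + exp(-a)), computed without overflow."""
    a = np.asarray(a, dtype=float)
    e = np.exp(-np.abs(a))                    # always in (0, 1]
    return np.where(a >= 0, 1.0 / (1.0 + e), e / (1.0 + e))

def logit(p):
    """Inverse of the sigmoid: log(p / (1 - p))."""
    p = np.asarray(p, dtype=float)
    return np.log(p) - np.log1p(-p)

def softplus(a):
    """sigma_+(a) = log(1 + exp(a)), via log(exp(0) + exp(a))."""
    return np.logaddexp(0.0, a)

a = np.linspace(-5.0, 5.0, 11)
eps = 1e-5

# Central finite differences for the two derivative identities.
grad_sigmoid = (sigmoid(a + eps) - sigmoid(a - eps)) / (2 * eps)
grad_softplus = (softplus(a + eps) - softplus(a - eps)) / (2 * eps)

print(np.allclose(1 - sigmoid(a), sigmoid(-a)))                  # (2.74)
print(np.allclose(logit(sigmoid(a)), a))                         # (2.75)
print(np.allclose(grad_sigmoid, sigmoid(a) * (1 - sigmoid(a))))  # (2.73)
print(np.allclose(grad_softplus, sigmoid(a)))                    # (2.77)
```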


2.4.3 Binary logistic regression 


In this section, we use a conditional Bernoulli model, where we use a linear predictor of the form 
$f(\boldsymbol{x}; \boldsymbol{\theta}) = \boldsymbol{w}^{\mathsf{T}} \boldsymbol{x} + b$. Thus the model has the form


$$p(y|\boldsymbol{x}; \boldsymbol{\theta}) = \mathrm{Ber}(y|\sigma(\boldsymbol{w}^{\mathsf{T}} \boldsymbol{x} + b)) \tag{2.86}$$
In other words,

$$p(y = 1|\boldsymbol{x}; \boldsymbol{\theta}) = \sigma(\boldsymbol{w}^{\mathsf{T}} \boldsymbol{x} + b) = \frac{1}{1 + e^{-(\boldsymbol{w}^{\mathsf{T}} \boldsymbol{x} + b)}} \tag{2.87}$$

This is called logistic regression. 

For example, consider a 1-dimensional, 2-class version of the iris dataset, where the positive class is
“Virginica” and the negative class is “not Virginica”, and the feature x we use is the petal width. We
fit a logistic regression model to this and show the results in Figure 2.11. The decision boundary
corresponds to the value x* where p(y = 1|x = x*, θ) = 0.5. We see that, in this example, x* ≈ 1.7.
As x moves away from this boundary, the classifier becomes more confident in its prediction about 
the class label. 

It should be clear from this example why it would be inappropriate to use linear regression for a 
(binary) classification problem. In such a model, the probabilities would increase above 1 as we move 
far enough to the right, and below 0 as we move far enough to the left. 

For more detail on logistic regression, see Chapter 10. 


2.5 Categorical and multinomial distributions 


To represent a distribution over a finite set of labels, y ∈ {1, . . . , C}, we can use the categorical
distribution, which generalizes the Bernoulli to C > 2 values. 


2.5.1 Definition 


The categorical distribution is a discrete probability distribution with one parameter per class: 
$$\mathrm{Cat}(y|\boldsymbol{\theta}) \triangleq \prod_{c=1}^{C} \theta_c^{\mathbb{I}(y = c)} \tag{2.88}$$


In other words, $p(y = c|\boldsymbol{\theta}) = \theta_c$. Note that the parameters are constrained so that $0 \le \theta_c \le 1$ and
$\sum_{c=1}^{C} \theta_c = 1$; thus there are only C − 1 independent parameters.

We can write the categorical distribution in another way by converting the discrete variable y into 
a one-hot vector with C elements, all of which are 0 except for the entry corresponding to the class 
label. (The term “one-hot” arises from electrical engineering, where binary vectors are encoded as 
electrical current on a set of wires, which can be active (“hot”) or not (“cold”).) For example, if C = 3, 
we encode the classes 1, 2 and 3 as (1,0,0), (0,1,0), and (0,0,1). More generally, we can encode the 
classes using unit vectors, where $\boldsymbol{e}_c$ is all 0s except for dimension c. (This is also called a dummy
encoding.) Using one-hot encodings, we can write the categorical distribution as follows: 


$$\mathrm{Cat}(\boldsymbol{y}|\boldsymbol{\theta}) \triangleq \prod_{c=1}^{C} \theta_c^{y_c} \tag{2.89}$$


The categorical distribution is a special case of the multinomial distribution. To explain this, 
suppose we observe N categorical trials, $y_n \sim \mathrm{Cat}(\cdot|\boldsymbol{\theta})$, for n = 1 : N. Concretely, think of rolling
a C-sided die N times. Let us define y to be a vector that counts the number of times each face
Figure 2.12: Softmax distribution softmax(a/T), where a = (3, 0, 1), at temperatures of T = 100, T = 2
and T = 1. When the temperature is high (left), the distribution is uniform, whereas when the temperature
is low (right), the distribution is “spiky”, with most of its mass on the largest element. Generated by
softmax_plot.ipynb.


shows up, i.e., $y_c = N_c \triangleq \sum_{n=1}^{N} \mathbb{I}(y_n = c)$. Now y is no longer one-hot, but is “multi-hot”, since it
has a non-zero entry for every value of c that was observed across all N trials. The distribution of y 
is given by the multinomial distribution: 


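$$\mathcal{M}(\boldsymbol{y}|N, \boldsymbol{\theta}) \triangleq \binom{N}{y_1 \ldots y_C} \prod_{c=1}^{C} \theta_c^{y_c} = \binom{N}{N_1 \ldots N_C} \prod_{c=1}^{C} \theta_c^{N_c} \tag{2.90}$$

where $\theta_c$ is the probability that side c shows up, and

$$\binom{N}{N_1 \ldots N_C} \triangleq \frac{N!}{N_1!\, N_2! \cdots N_C!} \tag{2.91}$$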


is the multinomial coefficient, which is the number of ways to divide a set of size $N = \sum_{c=1}^{C} N_c$
into subsets with sizes $N_1$ up to $N_C$. If N = 1, the multinomial distribution becomes the categorical
distribution. 
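
As a small illustration, the multinomial pmf can be computed directly from Equations (2.90) and (2.91). The sketch below uses only the standard library (Python 3.8 or later for `math.prod`); the function name `multinomial_pmf` and the parameter vector `theta` are hypothetical choices made for the example.

```python
from itertools import product
from math import factorial, prod

def multinomial_pmf(counts, theta):
    """M(counts | N, theta) with N = sum(counts)."""
    coef = factorial(sum(counts))
    for n_c in counts:
        coef //= factorial(n_c)       # stays an integer at every step
    return coef * prod(t**n for t, n in zip(theta, counts))

theta = [0.2, 0.5, 0.3]               # hypothetical face probabilities, C = 3

# N = 1: a one-hot count vector gives back the categorical probability.
p_one_hot = multinomial_pmf([0, 1, 0], theta)

# N = 4: the pmf sums to one over all count vectors with total N.
N = 4
total = sum(
    multinomial_pmf(c, theta)
    for c in product(range(N + 1), repeat=len(theta))
    if sum(c) == N
)
print(p_one_hot, multinomial_pmf([2, 1, 1], theta), total)
```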

2.5.2 Softmax function 


In the conditional case, we can define 


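$$p(y|\boldsymbol{x}, \boldsymbol{\theta}) = \mathrm{Cat}(y|f(\boldsymbol{x}; \boldsymbol{\theta})) \tag{2.92}$$

which we can also write as

$$p(y|\boldsymbol{x}, \boldsymbol{\theta}) = \mathcal{M}(y|1, f(\boldsymbol{x}; \boldsymbol{\theta})) \tag{2.93}$$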


We require that $0 \le f_c(\boldsymbol{x}; \boldsymbol{\theta}) \le 1$ and $\sum_{c=1}^{C} f_c(\boldsymbol{x}; \boldsymbol{\theta}) = 1$.

To avoid the requirement that f directly predict a probability vector, it is common to pass the 
output from f into the softmax function [Bri90], also called the multinomial logit. This is defined 
as follows: 

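$$\mathrm{softmax}(\boldsymbol{a}) \triangleq \left[ \frac{e^{a_1}}{\sum_{c'=1}^{C} e^{a_{c'}}}, \ldots, \frac{e^{a_C}}{\sum_{c'=1}^{C} e^{a_{c'}}} \right] \tag{2.94}$$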
Figure 2.13: Logistic regression on the 3-class, 2-feature version of the Iris dataset. Adapted from Figure
4.25 of [Gér19]. Generated by iris_logreg.ipynb.


This maps $\mathbb{R}^C$ to $[0, 1]^C$, and satisfies the constraints that $0 \le \mathrm{softmax}(\boldsymbol{a})_c \le 1$ and $\sum_{c=1}^{C} \mathrm{softmax}(\boldsymbol{a})_c = 1$.
The inputs to the softmax, a = f(x; θ), are called logits, and are a generalization of the log odds.

The softmax function is so-called since it acts a bit like the argmax function. To see this, let us 
divide each $a_c$ by a constant T called the temperature.⁸ Then as T → 0, we find


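$$\mathrm{softmax}(\boldsymbol{a}/T)_c = \begin{cases} 1.0 & \text{if } c = \operatorname{argmax}_{c'} a_{c'} \\ 0.0 & \text{otherwise} \end{cases} \tag{2.95}$$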


In other words, at low temperatures, the distribution puts most of its probability mass in the most 
probable state (this is called winner takes all), whereas at high temperatures, it spreads the mass 
uniformly. See Figure 2.12 for an illustration. 
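
The effect of temperature can be reproduced with a few lines of NumPy. The sketch below uses the logits and temperatures from Figure 2.12, plus one extra low temperature (T = 0.01) that we add to show the winner-takes-all limit. Subtracting the maximum before exponentiating is an implementation detail for numerical safety and does not change the result.

```python
import numpy as np

def softmax(a):
    """Softmax of a 1d array of logits."""
    a = np.asarray(a, dtype=float)
    z = np.exp(a - a.max())           # shift so the largest exponent is 0
    return z / z.sum()

a = np.array([3.0, 0.0, 1.0])         # logits from Figure 2.12
temperatures = (100.0, 2.0, 1.0, 0.01)
probs = {T: softmax(a / T) for T in temperatures}

for T in temperatures:
    print(f"T={T:>6}: {np.round(probs[T], 3)}")
```

At T = 100 the three probabilities are nearly equal, while at T = 0.01 essentially all of the mass sits on the first element, which has the largest logit.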


2.5.3 Multiclass logistic regression 


If we use a linear predictor of the form f(x; θ) = Wx + b, where W is a C × D matrix, and b is a
C-dimensional bias vector, the final model becomes


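$$p(y|\boldsymbol{x}; \boldsymbol{\theta}) = \mathrm{Cat}(y|\mathrm{softmax}(\mathbf{W}\boldsymbol{x} + \boldsymbol{b})) \tag{2.96}$$

Let a = Wx + b be the C-dimensional vector of logits. Then we can rewrite the above as follows:

$$p(y = c|\boldsymbol{x}; \boldsymbol{\theta}) = \frac{e^{a_c}}{\sum_{c'=1}^{C} e^{a_{c'}}} \tag{2.97}$$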


This is known as multinomial logistic regression. 
If we have just two classes, this reduces to binary logistic regression. To see this, note that 

$$\mathrm{softmax}(\boldsymbol{a})_0 = \frac{e^{a_0}}{e^{a_0} + e^{a_1}} = \frac{1}{1 + e^{a_1 - a_0}} = \sigma(a_0 - a_1) \tag{2.98}$$


so we can just train the model to predict $a = a_1 - a_0$. This can be done with a single weight vector
w; if we use the multi-class formulation, we will have two weight vectors, $\boldsymbol{w}_0$ and $\boldsymbol{w}_1$. Such a model
is over-parameterized, which can hurt interpretability, but the predictions will be the same. 
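
A quick numerical check of this reduction, and of the over-parameterization, is sketched below. The logit values and the shift are arbitrary illustrative numbers.

```python
import numpy as np

def softmax(a):
    a = np.asarray(a, dtype=float)
    z = np.exp(a - a.max())
    return z / z.sum()

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

a = np.array([0.3, -1.2])             # hypothetical logits (a_0, a_1)
p = softmax(a)

# Two-class softmax equals a sigmoid of the logit difference.
print(p[0], sigmoid(a[0] - a[1]))
print(p[1], sigmoid(a[1] - a[0]))

# Adding the same constant to both logits leaves the predictions unchanged,
# which is why the two-weight-vector model is over-parameterized.
p_shifted = softmax(a + 5.0)
print(np.allclose(p, p_shifted))
```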


8. This terminology comes from the area of statistical physics. The Boltzmann distribution is a distribution over 
states which has the same form as the softmax function. 




We discuss this in more detail in Section 10.3. For now, we just give an example. Figure 2.13 
shows what happens when we fit this model to the 3-class iris dataset, using just 2 features. We see 
that the decision boundaries between each class are linear. We can create nonlinear boundaries by 
transforming the features (e.g., using polynomials), as we discuss in Section 10.3.1. 


2.5.4 Log-sum-exp trick 


In this section, we discuss one important practical detail to pay attention to when working with
the softmax distribution. Suppose we want to compute the normalized probability $p_c = p(y = c|\boldsymbol{x})$,
which is given by

$$p_c = \frac{e^{a_c}}{Z(\boldsymbol{a})} = \frac{e^{a_c}}{\sum_{c'=1}^{C} e^{a_{c'}}} \tag{2.99}$$


where a = f(x; θ) are the logits. We might encounter numerical problems when computing the
partition function Z. For example, suppose we have 3 classes, with logits a = (0, 1, 0). Then we find
$Z = e^0 + e^1 + e^0 = 4.72$. But now suppose a = (1000, 1001, 1000); we find Z = ∞, since on a computer,
even using 64 bit precision, np.exp(1000)=inf. Similarly, suppose a = (−1000, −999, −1000); now
we find Z = 0, since np.exp(-1000)=0. To avoid numerical problems, we can use the following
identity: 


$$\log \sum_{c=1}^{C} \exp(a_c) = m + \log \sum_{c=1}^{C} \exp(a_c - m) \tag{2.100}$$


This holds for any m. It is common to use $m = \max_c a_c$ which ensures that the largest value you
exponentiate will be zero, so you will definitely not overflow, and even if you underflow, the answer 
will be sensible. This is known as the log-sum-exp trick. We use this trick when implementing the 
lse function: 


$$\mathrm{lse}(\boldsymbol{a}) \triangleq \log \sum_{c=1}^{C} \exp(a_c) \tag{2.101}$$


We can use this to compute the probabilities from the logits: 

$$p(y = c|\boldsymbol{x}) = \exp(a_c - \mathrm{lse}(\boldsymbol{a})) \tag{2.102}$$


We can then pass this to the cross-entropy loss, defined in Equation (5.41). 
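
The following sketch shows one way to implement the lse function and to compare it with the naive computation on the three logit vectors discussed above. It is a minimal illustration rather than a library-grade routine (for instance, it does not handle infinite logits); the function names are our own. In practice a library routine such as `scipy.special.logsumexp` serves the same purpose.

```python
import numpy as np

def lse(a):
    """log(sum(exp(a))) using the max-shift identity of Equation (2.100)."""
    a = np.asarray(a, dtype=float)
    m = a.max()
    return m + np.log(np.sum(np.exp(a - m)))

def softmax_from_logits(a):
    """p_c = exp(a_c - lse(a)), as in Equation (2.102)."""
    a = np.asarray(a, dtype=float)
    return np.exp(a - lse(a))

for logits in ([0, 1, 0], [1000, 1001, 1000], [-1000, -999, -1000]):
    arr = np.array(logits, dtype=float)
    # The naive version overflows to inf or underflows to log(0) = -inf.
    with np.errstate(over="ignore", divide="ignore"):
        naive = np.log(np.sum(np.exp(arr)))
    print(logits, "naive:", naive, "stable:", lse(arr),
          "probs:", np.round(softmax_from_logits(arr), 4))
```

All three logit vectors differ only by a constant shift, so the stable version returns the same probability vector for each of them, whereas the naive log-partition function is only finite for the first.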

However, to save computational effort, and for numerical stability, it is quite common to modify 
the cross-entropy loss so that it takes the logits a as inputs, instead of the probability vector p. For
example, consider the binary case. The CE loss for one example is 


$$\mathcal{L} = -\left[\mathbb{I}(y = 0) \log p_0 + \mathbb{I}(y = 1) \log p_1\right] \tag{2.103}$$

where 

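\begin{align}
\log p_1 &= \log\left(\frac{1}{1 + \exp(-a)}\right) = \log(1) - \log(1 + \exp(-a)) = 0 - \mathrm{lse}([0, -a]) \tag{2.104} \\
\log p_0 &= 0 - \mathrm{lse}([0, +a]) \tag{2.105}
\end{align}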
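
A minimal sketch of the binary cross-entropy computed directly from the logit is shown below. It relies on `np.logaddexp(0, z)`, which evaluates log(e^0 + e^z) = lse([0, z]) stably; the function name `binary_ce_from_logit` is our own.

```python
import numpy as np

def binary_ce_from_logit(y, a):
    """Cross-entropy -log p(y) for label y in {0, 1} and logit a."""
    log_p1 = -np.logaddexp(0.0, -a)   # log p_1 = 0 - lse([0, -a])
    log_p0 = -np.logaddexp(0.0, a)    # log p_0 = 0 - lse([0, +a])
    return -(y * log_p1 + (1 - y) * log_p0)

# Moderate logit: agrees with the naive formula -log(sigmoid(a)).
a = 2.0
naive = -np.log(1.0 / (1.0 + np.exp(-a)))
print(binary_ce_from_logit(1, a), naive)

# Extreme logit: the naive route would need log(0), but this stays finite.
print(binary_ce_from_logit(0, 1000.0), binary_ce_from_logit(1, 1000.0))
```

For a confident wrong prediction (y = 0 with a = 1000) the loss is approximately 1000 rather than infinity, so gradients remain usable.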


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 




2.6 Univariate Gaussian (normal) distribution 


The most widely used distribution of real-valued random variables y ∈ R is the Gaussian distribution,
also called the normal distribution (see Section 2.6.4 for a discussion of these names).


2.6.1 Cumulative distribution function 


We define the cumulative distribution function or cdf of a continuous random variable Y as 
follows: 


$$P(y) \triangleq \Pr(Y \le y) \tag{2.106}$$


(Note that we use a capital P to represent the cdf.) Using this, we can compute the probability of 
being in any interval as follows: 


$$\Pr(a < Y \le b) = P(b) - P(a) \tag{2.107}$$


Cdf’s are monotonically non-decreasing functions. 
The cdf of the Gaussian is defined by 


$$\Phi(y; \mu, \sigma^2) \triangleq \int_{-\infty}^{y} \mathcal{N}(z|\mu, \sigma^2)\, dz \tag{2.108}$$


See Figure 2.2a for a plot. Note that the cdf of the Gaussian is often implemented using
$\Phi(y; \mu, \sigma^2) = \frac{1}{2}[1 + \mathrm{erf}(z/\sqrt{2})]$, where $z = (y - \mu)/\sigma$ and erf(u) is the error function, defined as


$$\mathrm{erf}(u) \triangleq \frac{2}{\sqrt{\pi}} \int_{0}^{u} e^{-t^2}\, dt \tag{2.109}$$


The parameter µ encodes the mean of the distribution; in the case of a Gaussian, this is also
the same as the mode. The parameter σ² encodes the variance. (Sometimes we talk about the
precision of a Gaussian, which is the inverse variance, denoted λ = 1/σ².) When µ = 0 and σ = 1,
the Gaussian is called the standard normal distribution. 

If P is the cdf of Y, then $P^{-1}(q)$ is the value $y_q$ such that $p(Y \le y_q) = q$; this is called the q’th
quantile of P. The value $P^{-1}(0.5)$ is the median of the distribution, with half of the probability
mass on the left, and half on the right. The values $P^{-1}(0.25)$ and $P^{-1}(0.75)$ are the lower and upper
quartiles. 

For example, let Φ be the cdf of the Gaussian distribution N(0, 1), and $\Phi^{-1}$ be the inverse cdf (also
known as the probit function). Then points to the left of $\Phi^{-1}(\alpha/2)$ contain α/2 of the probability
mass, as illustrated in Figure 2.2b. By symmetry, points to the right of $\Phi^{-1}(1 - \alpha/2)$ also contain
α/2 of the mass. Hence the central interval $(\Phi^{-1}(\alpha/2), \Phi^{-1}(1 - \alpha/2))$ contains 1 − α of the mass. If
we set α = 0.05, the central 95% interval is covered by the range


$$(\Phi^{-1}(0.025), \Phi^{-1}(0.975)) = (-1.96, 1.96) \tag{2.110}$$


If the distribution is N(µ, σ²), then the 95% interval becomes (µ − 1.96σ, µ + 1.96σ). This is often
approximated by writing µ ± 2σ.
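
These quantities can be checked with the Python standard library. The sketch below implements the cdf via the error function, as described above, and uses `statistics.NormalDist.inv_cdf` for the probit function; the helper name `gauss_cdf` is our own.

```python
import math
from statistics import NormalDist

def gauss_cdf(y, mu=0.0, sigma=1.0):
    """Gaussian cdf via the error function: 0.5 * (1 + erf(z / sqrt(2)))."""
    z = (y - mu) / sigma
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

std_normal = NormalDist(0.0, 1.0)
alpha = 0.05
lo = std_normal.inv_cdf(alpha / 2)         # probit(0.025)
hi = std_normal.inv_cdf(1 - alpha / 2)     # probit(0.975)
mass = gauss_cdf(hi) - gauss_cdf(lo)       # mass inside the central interval

# The common "mean plus or minus 2 sigma" rule covers slightly more than 95%.
mass_2sigma = gauss_cdf(2.0) - gauss_cdf(-2.0)

print(f"interval=({lo:.3f}, {hi:.3f}), mass={mass:.4f}, 2-sigma mass={mass_2sigma:.4f}")
```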
2.6.2 Probability density function

We define the probability density function or pdf as the derivative of the cdf: 


$$p(y) \triangleq \frac{d}{dy} P(y) \tag{2.111}$$


The pdf of the Gaussian is given by

$$\mathcal{N}(y|\mu, \sigma^2) \triangleq \frac{1}{\sqrt{2\pi\sigma^2}}\, e^{-\frac{1}{2\sigma^2}(y - \mu)^2} \tag{2.112}$$

where $\sqrt{2\pi\sigma^2}$ is the normalization constant needed to ensure the density integrates to 1 (see
Exercise 2.12). See Figure 2.2b for a plot.

Given a pdf, we can compute the probability of a continuous variable being in a finite interval as
follows:

$$\Pr(a < Y \le b) = \int_{a}^{b} p(y)\, dy = P(b) - P(a) \tag{2.113}$$


As the size of the interval gets smaller, we can write 

$$\Pr(y \le Y \le y + dy) \approx p(y)\, dy \tag{2.114}$$


Intuitively, this says the probability of Y being in a small interval around y is the density at y times 
the width of the interval. One important consequence of the above result is that the pdf at a point 
can be larger than 1. For example, a Gaussian with mean 0 and standard deviation σ = 0.1 (variance σ² = 0.01) has density N(0|0, 0.01) = 3.99 at its mean.
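
The following sketch evaluates the Gaussian pdf at its mean and confirms that the density still integrates to one. The value 3.99 corresponds to a standard deviation of 0.1 (so the variance passed to the function is 0.01). The function name `gauss_pdf` and the grid spacing are illustrative choices.

```python
import math

def gauss_pdf(y, mu, sigma2):
    """Gaussian density N(y | mu, sigma2), parameterized by the variance."""
    return math.exp(-((y - mu) ** 2) / (2.0 * sigma2)) / math.sqrt(2.0 * math.pi * sigma2)

sigma = 0.1
peak = gauss_pdf(0.0, 0.0, sigma**2)       # density at the mean, exceeds 1

# Riemann sum over [-1, 1], i.e. 10 standard deviations either side of the mean.
dy = 1e-4
area = sum(gauss_pdf(-1.0 + i * dy, 0.0, sigma**2) * dy for i in range(20001))

# Probability of a small interval is roughly density times width.
small_interval_prob = peak * dy

print(f"peak={peak:.3f}, area={area:.6f}, Pr(0 <= Y <= dy) is about {small_interval_prob:.6f}")
```

A density value above 1 is not a contradiction, because probabilities come from multiplying the density by an interval width.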

We can use the pdf to compute the mean, or expected value, of the distribution: 


$$\mathbb{E}[Y] \triangleq \int_{\mathcal{Y}} y\, p(y)\, dy \tag{2.115}$$


For a Gaussian, we have the familiar result that $\mathbb{E}[\mathcal{N}(\cdot|\mu, \sigma^2)] = \mu$. (Note, however, that for some
distributions, this integral is not finite, so the mean is not defined.) 

We can also use the pdf to compute the variance of a distribution. This is a measure of the 
“spread”, and is often denoted by σ². The variance is defined as follows:


VIY] £E[(¥ —»)?] = I (y — apud (2.116) 


= J Poodu +2 f yay 2 | weluddy =E|¥?]-— 4 (2.117) 


from which we derive the useful result 


6 [¥?] = + p? (2.118) 


The standard deviation is defined as 
std [Y] = VY [Y] =o (2.119) 


(The standard deviation can be more interpretable than the variance since it has the same units as Y
itself.) For a Gaussian, we have the familiar result that $\text{std}[\mathcal{N}(\cdot|\mu, \sigma^2)] = \sigma$.

Figure 2.14: Linear regression using Gaussian output with mean $\mu(x) = b + wx$ and (a) fixed variance
$\sigma^2$ (homoscedastic) or (b) input-dependent variance $\sigma(x)^2$ (heteroscedastic). Generated by
linreg_1d_hetero_tfp.ipynb.
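
The moment identities in Equations (2.117) to (2.119) can be checked by simulation. The sketch below draws Gaussian samples with the standard library; the values of `mu` and `sigma` are arbitrary, and the estimates are only approximate.

```python
import random
from statistics import fmean

random.seed(0)
mu, sigma = 2.0, 3.0
samples = [random.gauss(mu, sigma) for _ in range(200_000)]

mean_hat = fmean(samples)                            # estimates E[Y] = mu
second_moment_hat = fmean(y * y for y in samples)    # estimates E[Y^2] = sigma^2 + mu^2
var_hat = second_moment_hat - mean_hat ** 2          # V[Y] = E[Y^2] - mu^2
std_hat = var_hat ** 0.5                             # std[Y] = sqrt(V[Y])
```

With these settings the second moment should be near $\sigma^2 + \mu^2 = 13$.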


2.6.3 Regression 


So far we have been considering the unconditional Gaussian distribution. In some cases, it is helpful
to make the parameters of the Gaussian be functions of some input variables, i.e., we want to create
a conditional density model of the form

$$ p(y|\boldsymbol{x}; \boldsymbol{\theta}) = \mathcal{N}(y|f_\mu(\boldsymbol{x}; \boldsymbol{\theta}), f_\sigma(\boldsymbol{x}; \boldsymbol{\theta})^2) \tag{2.120} $$

where $f_\mu(\boldsymbol{x}; \boldsymbol{\theta}) \in \mathbb{R}$ predicts the mean, and $f_\sigma(\boldsymbol{x}; \boldsymbol{\theta})^2 \in \mathbb{R}_+$ predicts the variance.

It is common to assume that the variance is fixed, and is independent of the input. This is called
homoscedastic regression. Furthermore, it is common to assume the mean is a linear function of
the input. The resulting model is called linear regression:

$$ p(y|\boldsymbol{x}; \boldsymbol{\theta}) = \mathcal{N}(y|\boldsymbol{w}^\top \boldsymbol{x} + b, \sigma^2) \tag{2.121} $$

where $\boldsymbol{\theta} = (\boldsymbol{w}, b, \sigma^2)$. See Figure 2.14(a) for an illustration of this model in 1d, and Section 11.2 for
more details on this model.

However, we can also make the variance depend on the input; this is called heteroscedastic
regression. In the linear regression setting, we have

$$ p(y|\boldsymbol{x}; \boldsymbol{\theta}) = \mathcal{N}(y|\boldsymbol{w}_\mu^\top \boldsymbol{x} + b, \sigma_+(\boldsymbol{w}_\sigma^\top \boldsymbol{x})) \tag{2.122} $$

where $\boldsymbol{\theta} = (\boldsymbol{w}_\mu, \boldsymbol{w}_\sigma)$ are the two forms of regression weights, and

$$ \sigma_+(a) = \log(1 + e^a) \tag{2.123} $$

is the softplus function, which maps from $\mathbb{R}$ to $\mathbb{R}_+$, to ensure the predicted standard deviation is
non-negative. See Figure 2.14(b) for an illustration of this model in 1d.

Note that Figure 2.14 plots the 95% predictive interval, $[\mu(x) - 2\sigma(x), \mu(x) + 2\sigma(x)]$. This is the
uncertainty in the predicted observation y given x, and captures the variability in the blue dots.
By contrast, the uncertainty in the underlying (noise-free) function is represented by $\sqrt{\mathbb{V}[f_\mu(\boldsymbol{x}; \boldsymbol{\theta})]}$,
which does not involve the $\sigma$ term; now the uncertainty is over the parameters $\boldsymbol{\theta}$, rather than the
output y. See Section 11.7 for details on how to model parameter uncertainty.
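
As a rough illustration of the heteroscedastic model in 1d, the sketch below evaluates the softplus function and the approximate 95% predictive interval. The parameter values are hypothetical (they are not fitted to any data), and the function names are ours.

```python
import math

def softplus(a: float) -> float:
    """Numerically stable log(1 + exp(a)); non-negative in floating point."""
    return max(a, 0.0) + math.log1p(math.exp(-abs(a)))

def predictive_interval(x: float, w_mu: float, b: float, w_sigma: float) -> tuple[float, float]:
    """Approximate 95% interval [mu(x) - 2 sigma(x), mu(x) + 2 sigma(x)] for a 1d input."""
    mu = w_mu * x + b
    sigma = softplus(w_sigma * x)  # treated here as the predicted standard deviation
    return mu - 2.0 * sigma, mu + 2.0 * sigma

# Hypothetical parameter values, chosen only for illustration.
w_mu, b, w_sigma = 1.5, -0.5, 0.8
intervals = {x: predictive_interval(x, w_mu, b, w_sigma) for x in (-2.0, 0.0, 2.0)}
widths = {x: hi - lo for x, (lo, hi) in intervals.items()}
```

Because `w_sigma` is positive in this example, the interval widens as x increases, which mimics the fan shape of a heteroscedastic fit.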
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2.6.4 Why is the Gaussian distribution so widely used? 


The Gaussian distribution is the most widely used distribution in statistics and machine learning. 
There are several reasons for this. First, it has two parameters which are easy to interpret, and which 
capture some of the most basic properties of a distribution, namely its mean and variance. Second, 
the central limit theorem (Section 2.8.6) tells us that sums of independent random variables have an 
approximately Gaussian distribution, making it a good choice for modeling residual errors or “noise”. 
Third, the Gaussian distribution makes the least number of assumptions (has maximum entropy), 
subject to the constraint of having a specified mean and variance, as we show in Section 3.4.4; this 
makes it a good default choice in many cases. Finally, it has a simple mathematical form, which 
results in easy to implement, but often highly effective, methods, as we will see in Section 3.2. 

From a historical perspective, it’s worth remarking that the term “Gaussian distribution” is a bit 
misleading, since, as Jaynes [Jay03, p241] notes: “The fundamental nature of this distribution and 
its main properties were noted by Laplace when Gauss was six years old; and the distribution itself 
had been found by de Moivre before Laplace was born”. However, Gauss popularized the use of the 
distribution in the 1800s, and the term “Gaussian” is now widely used in science and engineering. 

The name “normal distribution” seems to have arisen in connection with the normal equations 
in linear regression (see Section 11.2.2.2). However, we prefer to avoid the term “normal”, since it 
suggests other distributions are “abnormal”, whereas, as Jaynes [Jay03] points out, it is the Gaussian 
that is abnormal in the sense that it has many special properties that are untypical of general 
distributions. 


2.6.5 Dirac delta function as a limiting case 


As the variance of a Gaussian goes to 0, the distribution approaches an infinitely narrow, but infinitely
tall, “spike” at the mean. We can write this as follows:

$$ \lim_{\sigma \to 0} \mathcal{N}(y|\mu, \sigma^2) \to \delta(y - \mu) \tag{2.124} $$

where $\delta$ is the Dirac delta function, defined by

$$ \delta(x) = \begin{cases} +\infty & \text{if } x = 0 \\ 0 & \text{if } x \ne 0 \end{cases} \tag{2.125} $$

where

$$ \int_{-\infty}^{\infty} \delta(x)\,dx = 1 \tag{2.126} $$

A slight variant of this is to define

$$ \delta_y(x) = \begin{cases} +\infty & \text{if } x = y \\ 0 & \text{if } x \ne y \end{cases} \tag{2.127} $$

Note that we have

$$ \delta_y(x) = \delta(x - y) \tag{2.128} $$

Figure 2.15: (a) The pdf’s for a $\mathcal{N}(0, 1)$, $\mathcal{T}(\mu = 0, \sigma = 1, \nu = 1)$, $\mathcal{T}(\mu = 0, \sigma = 1, \nu = 2)$, and
$\text{Laplace}(0, 1/\sqrt{2})$. The mean is 0 and the variance is 1 for both the Gaussian and Laplace. When $\nu = 1$,
the Student is the same as the Cauchy, which does not have a well-defined mean and variance. (b) Log of
these pdf’s. Note that the Student distribution is not log-concave for any parameter value, unlike the Laplace
distribution. Nevertheless, both are unimodal. Generated by student_laplace_pdf_plot.ipynb.

The delta function distribution satisfies the following sifting property, which we will use later on:

$$ \int_{-\infty}^{\infty} f(y)\,\delta(x - y)\,dy = f(x) \tag{2.129} $$
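
One way to build intuition for the sifting property is to replace the delta function by a narrow Gaussian centered at x and watch the integral approach f(x) as the variance shrinks. The sketch below does this with a simple midpoint rule; the test function (cosine), the point `x0`, and the grid size are arbitrary choices made for illustration.

```python
import math
from statistics import NormalDist

def smoothed_sift(f, x: float, sigma: float, n: int = 4_001) -> float:
    """Midpoint-rule estimate of the integral of f(y) N(y | x, sigma^2) dy on x +/- 8 sigma."""
    lo, hi = x - 8.0 * sigma, x + 8.0 * sigma
    h = (hi - lo) / n
    kernel = NormalDist(x, sigma)
    total = 0.0
    for i in range(n):
        y = lo + (i + 0.5) * h
        total += f(y) * kernel.pdf(y)
    return total * h

x0 = 0.7
# Error |integral - f(x0)| for progressively narrower Gaussians.
errors = [abs(smoothed_sift(math.cos, x0, s) - math.cos(x0)) for s in (1.0, 0.1, 0.01)]
```

The errors shrink as the standard deviation decreases, consistent with the limit in Equation (2.124).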


2.7 Some other common univariate distributions * 


In this section, we briefly introduce some other univariate distributions that we will use in this book. 


2.7.1 Student t distribution 


The Gaussian distribution is quite sensitive to outliers. A robust alternative to the Gaussian is
the Student t-distribution, which we shall call the Student distribution for short.⁹ Its pdf is
as follows:

$$ \mathcal{T}(y|\mu, \sigma^2, \nu) \propto \left[1 + \frac{1}{\nu}\left(\frac{y - \mu}{\sigma}\right)^2\right]^{-\left(\frac{\nu + 1}{2}\right)} \tag{2.130} $$

where $\mu$ is the mean, $\sigma > 0$ is the scale parameter (not the standard deviation), and $\nu > 0$ is called
the degrees of freedom (although a better term would be the degree of normality [Kru13], since
large values of $\nu$ make the distribution act like a Gaussian).


9. This distribution has a colorful etymology. It was first published in 1908 by William Sealy Gosset, who worked at 
the Guinness brewery in Dublin, Ireland. Since his employer would not allow him to use his own name, he called it the 
“Student” distribution. The origin of the term t seems to have arisen in the context of tables of the Student distribution, 
used by Fisher when developing the basis of classical statistical inference. See http://jeff560.tripod.com/s.html for 
more historical details.

Figure 2.16: Illustration of the effect of outliers on fitting Gaussian, Student and Laplace distributions. (a)
No outliers (the Gaussian and Student curves are on top of each other). (b) With outliers. We see that the
Gaussian is more affected by outliers than the Student and Laplace distributions. Adapted from Figure 2.16
of [Bis06]. Generated by robust_pdf_plot.ipynb.


We see that the probability density decays as a polynomial function of the squared distance from 
the center, as opposed to an exponential function, so there is more probability mass in the tail than 
with a Gaussian distribution, as shown in Figure 2.15. We say that the Student distribution has 
heavy tails, which makes it robust to outliers. 

To illustrate the robustness of the Student distribution, consider Figure 2.16. On the left, we show 
a Gaussian and a Student distribution fit to some data with no outliers. On the right, we add some 
outliers. We see that the Gaussian is affected a lot, whereas the Student hardly changes. We discuss 
how to use the Student distribution for robust linear regression in Section 11.6.2. 

For later reference, we note that the Student distribution has the following properties:

$$ \text{mean} = \mu, \quad \text{mode} = \mu, \quad \text{var} = \frac{\nu\sigma^2}{\nu - 2} \tag{2.131} $$

The mean is only defined if $\nu > 1$. The variance is only defined if $\nu > 2$. For $\nu \gg 5$, the Student
distribution rapidly approaches a Gaussian distribution and loses its robustness properties. It is
common to use $\nu = 4$, which gives good performance in a range of problems [LLT89].
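
To see the heavy tails numerically, the sketch below compares a Student density with $\nu = 4$ to a standard Gaussian at increasing distances from the center. The text gives the Student pdf only up to proportionality; the normalization constant used here, $\Gamma(\frac{\nu+1}{2}) / (\Gamma(\frac{\nu}{2})\sqrt{\nu\pi}\,\sigma)$, is the standard one and is an addition on our part.

```python
import math
from statistics import NormalDist

def student_pdf(y: float, mu: float = 0.0, sigma: float = 1.0, nu: float = 4.0) -> float:
    """Normalized Student density; the constant is the standard one, not stated in the text."""
    log_norm = (math.lgamma((nu + 1) / 2) - math.lgamma(nu / 2)
                - 0.5 * math.log(nu * math.pi) - math.log(sigma))
    z = (y - mu) / sigma
    return math.exp(log_norm - 0.5 * (nu + 1) * math.log1p(z * z / nu))

gauss = NormalDist(0.0, 1.0)

# Ratio of Student to Gaussian density at several points; it grows rapidly in the tails.
tail_ratio = {y: student_pdf(y) / gauss.pdf(y) for y in (0.0, 2.0, 4.0, 6.0)}
```

Near the center the two densities are comparable, but far from the center the Student density is larger by orders of magnitude, which is why a single outlier has much less influence on a Student fit.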


2.7.2 Cauchy distribution 
If $\nu = 1$, the Student distribution is known as the Cauchy or Lorentz distribution. Its pdf is defined
by

$$ \mathcal{C}(x|\mu, \gamma) = \frac{1}{\gamma\pi}\left[1 + \left(\frac{x - \mu}{\gamma}\right)^2\right]^{-1} \tag{2.132} $$

This distribution has very heavy tails compared to a Gaussian. For example, 95% of the values from
a standard normal are between -1.96 and 1.96, but for a standard Cauchy they are between -12.7
and 12.7. In fact the tails are so heavy that the integral that defines the mean does not converge.

The half Cauchy distribution is a version of the Cauchy (with $\mu = 0$) that is “folded over” on
itself, so all its probability density is on the positive reals. Thus it has the form

$$ \mathcal{C}_+(x|\gamma) \triangleq \frac{2}{\pi\gamma}\left[1 + \left(\frac{x}{\gamma}\right)^2\right]^{-1} \tag{2.133} $$

This is useful in Bayesian modeling, where we want to use a distribution over positive reals with
heavy tails, but finite density at the origin.
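
The 95% ranges quoted above for the standard normal and the standard Cauchy can be reproduced from the inverse cdfs. The Cauchy inverse cdf, $\mu + \gamma\tan(\pi(p - \tfrac{1}{2}))$, is a standard closed form that is not given in the text; the function name below is ours.

```python
import math
from statistics import NormalDist

def cauchy_quantile(p: float, mu: float = 0.0, gamma: float = 1.0) -> float:
    """Inverse cdf of the Cauchy distribution (standard closed form, not given in the text)."""
    return mu + gamma * math.tan(math.pi * (p - 0.5))

gauss_hi = NormalDist().inv_cdf(0.975)   # upper end of the central 95% interval
cauchy_hi = cauchy_quantile(0.975)
print(f"Gaussian: +/-{gauss_hi:.2f}, Cauchy: +/-{cauchy_hi:.1f}")
```

This should print endpoints of about 1.96 and 12.7 respectively, matching the figures in the text.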


2.7.3 Laplace distribution 


Another distribution with heavy tails is the Laplace distribution¹⁰, also known as the double
sided exponential distribution. This has the following pdf:

$$ \text{Laplace}(y|\mu, b) \triangleq \frac{1}{2b}\exp\left(-\frac{|y - \mu|}{b}\right) \tag{2.134} $$

See Figure 2.15 for a plot. Here $\mu$ is a location parameter and $b > 0$ is a scale parameter. This
distribution has the following properties:

$$ \text{mean} = \mu, \quad \text{mode} = \mu, \quad \text{var} = 2b^2 \tag{2.135} $$


In Section 11.6.1, we discuss how to use the Laplace distribution for robust linear regression, and 
in Section 11.4, we discuss how to use the Laplace distribution for sparse linear regression. 


2.7.4 Beta distribution 


The beta distribution has support over the interval [0, 1] and is defined as follows:

$$ \text{Beta}(x|a, b) = \frac{1}{B(a, b)} x^{a-1}(1 - x)^{b-1} \tag{2.136} $$

where $B(a, b)$ is the beta function, defined by

$$ B(a, b) \triangleq \frac{\Gamma(a)\Gamma(b)}{\Gamma(a + b)} \tag{2.137} $$

where $\Gamma(a)$ is the Gamma function defined by

$$ \Gamma(a) \triangleq \int_0^\infty x^{a-1} e^{-x}\,dx \tag{2.138} $$

See Figure 2.17a for plots of some beta distributions.
We require $a, b > 0$ to ensure the distribution is integrable (i.e., to ensure $B(a, b)$ exists). If
$a = b = 1$, we get the uniform distribution. If a and b are both less than 1, we get a bimodal
distribution with “spikes” at 0 and 1; if a and b are both greater than 1, the distribution is unimodal.
For later reference, we note that the distribution has the following properties (Exercise 2.8):

$$ \text{mean} = \frac{a}{a + b}, \quad \text{mode} = \frac{a - 1}{a + b - 2}, \quad \text{var} = \frac{ab}{(a + b)^2(a + b + 1)} \tag{2.139} $$


10. Pierre-Simon Laplace (1749-1827) was a French mathematician, who played a key role in creating the field of
Bayesian statistics.

Figure 2.17: (a) Some beta distributions. If a < 1, we get a “spike” on the left, and if b < 1, we get a “spike”
on the right. If a = b = 1, the distribution is uniform. If a > 1 and b > 1, the distribution is unimodal.
Generated by beta_dist_plot.ipynb. (b) Some gamma distributions. If a ≤ 1, the mode is at 0, otherwise the
mode is away from 0. As we increase the rate b, we reduce the horizontal scale, thus squeezing everything
leftwards and upwards. Generated by gamma_dist_plot.ipynb.
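
The closed-form beta properties in Equation (2.139) can be compared against a direct numerical integration of the density. The sketch below uses `math.lgamma` for the beta function and a midpoint rule on (0, 1); the parameter values and grid size are illustrative choices, and the function names are ours.

```python
import math

def beta_pdf(x: float, a: float, b: float) -> float:
    """Beta(x | a, b) using B(a, b) = Gamma(a) Gamma(b) / Gamma(a + b)."""
    log_B = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return math.exp((a - 1) * math.log(x) + (b - 1) * math.log(1 - x) - log_B)

def beta_moments_numeric(a: float, b: float, n: int = 20_000) -> tuple[float, float, float]:
    """Midpoint-rule estimates of (mean, variance, mode) on a grid over (0, 1)."""
    h = 1.0 / n
    xs = [(i + 0.5) * h for i in range(n)]
    ps = [beta_pdf(x, a, b) for x in xs]
    mean = sum(x * p for x, p in zip(xs, ps)) * h
    var = sum((x - mean) ** 2 * p for x, p in zip(xs, ps)) * h
    mode = xs[max(range(n), key=ps.__getitem__)]   # grid point with the highest density
    return mean, var, mode

a, b = 2.0, 5.0  # illustrative values with a, b > 1 so the mode is interior
mean_num, var_num, mode_num = beta_moments_numeric(a, b)

# Closed forms from Equation (2.139).
mean_cf = a / (a + b)
mode_cf = (a - 1) / (a + b - 2)
var_cf = a * b / ((a + b) ** 2 * (a + b + 1))
```

The numerical and closed-form values should agree to several decimal places; the mode estimate is limited by the grid spacing.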


2.7.5 Gamma distribution 


The gamma distribution is a flexible distribution for positive real valued rv’s, $x > 0$. It is defined
in terms of two parameters, called the shape $a > 0$ and the rate $b > 0$:

$$ \text{Ga}(x|\text{shape} = a, \text{rate} = b) \triangleq \frac{b^a}{\Gamma(a)} x^{a-1} e^{-xb} \tag{2.140} $$

Sometimes the distribution is parameterized in terms of the shape a and the scale $s = 1/b$:

$$ \text{Ga}(x|\text{shape} = a, \text{scale} = s) \triangleq \frac{1}{s^a \Gamma(a)} x^{a-1} e^{-x/s} \tag{2.141} $$

See Figure 2.17b for some plots of the gamma pdf.
For reference, we note that the distribution has the following properties:

$$ \text{mean} = \frac{a}{b}, \quad \text{mode} = \frac{a - 1}{b}, \quad \text{var} = \frac{a}{b^2} \tag{2.142} $$

There are several distributions which are just special cases of the Gamma, which we discuss below.

• Exponential distribution. This is defined by

$$ \text{Expon}(x|\lambda) \triangleq \text{Ga}(x|\text{shape} = 1, \text{rate} = \lambda) \tag{2.143} $$

  This distribution describes the times between events in a Poisson process, i.e. a process in which
  events occur continuously and independently at a constant average rate $\lambda$.

• Chi-squared distribution. This is defined by

$$ \chi^2_\nu(x) \triangleq \text{Ga}(x|\text{shape} = \tfrac{\nu}{2}, \text{rate} = \tfrac{1}{2}) \tag{2.144} $$

  where $\nu$ is called the degrees of freedom. This is the distribution of the sum of squared Gaussian
  random variables. More precisely, if $Z_i \sim \mathcal{N}(0, 1)$, and $S = \sum_{i=1}^{\nu} Z_i^2$, then $S \sim \chi^2_\nu$.

• The inverse Gamma distribution is defined as follows:

$$ \text{IG}(x|\text{shape} = a, \text{scale} = b) \triangleq \frac{b^a}{\Gamma(a)} x^{-(a+1)} e^{-b/x} \tag{2.145} $$

  The distribution has these properties

$$ \text{mean} = \frac{b}{a - 1}, \quad \text{mode} = \frac{b}{a + 1}, \quad \text{var} = \frac{b^2}{(a - 1)^2(a - 2)} \tag{2.146} $$

  The mean only exists if $a > 1$. The variance only exists if $a > 2$. Note: if $X \sim \text{Ga}(\text{shape} = a, \text{rate} = b)$,
  then $1/X \sim \text{IG}(\text{shape} = a, \text{scale} = b)$. (Note that b plays two different roles in this
  case.)

Figure 2.18: Illustration of the (a) empirical pdf and (b) empirical cdf derived from a set of N = 5 samples.
From https://bit.ly/3hFgide. Used with kind permission of Mauro Escudero.
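
The rate versus scale distinction is easy to get wrong in code. As a sketch, Python's `random.gammavariate(alpha, beta)` takes a shape and a scale, so a rate b must be passed as 1/b. The simulation below checks the gamma moments in Equation (2.142) and the inverse gamma moments in Equation (2.146); the parameter values and the helper `mean_and_var` are illustrative choices of ours.

```python
import random

def mean_and_var(values):
    """Return the sample mean and the (population-style) sample variance."""
    m = sum(values) / len(values)
    return m, sum((v - m) ** 2 for v in values) / len(values)

random.seed(0)
a, b = 5.0, 2.0          # shape and rate (illustrative values with a > 2)
n = 100_000

# random.gammavariate(alpha, beta) takes a shape and a SCALE, so pass 1 / rate.
xs = [random.gammavariate(a, 1.0 / b) for _ in range(n)]
inv_xs = [1.0 / x for x in xs]          # 1/X should follow IG(shape = a, scale = b)

gamma_mean, gamma_var = mean_and_var(xs)     # expect a / b and a / b^2
ig_mean, ig_var = mean_and_var(inv_xs)       # expect b / (a - 1) and b^2 / ((a - 1)^2 (a - 2))
```

With these values the estimates should be close to 2.5 and 1.25 for the gamma, and 0.5 and 1/12 for the inverse gamma.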


2.7.6 Empirical distribution 


Suppose we have a set of N samples $\mathcal{D} = \{x^{(1)}, \ldots, x^{(N)}\}$, derived from a distribution $p(X)$, where
$X \in \mathbb{R}$. We can approximate the pdf using a set of delta functions (Section 2.6.5) or “spikes”, centered
on these samples:

$$ \hat{p}_N(x) = \frac{1}{N}\sum_{n=1}^{N} \delta_{x^{(n)}}(x) \tag{2.147} $$


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 


66 Chapter 2. Probability: Univariate Models 


This is called the empirical distribution of the dataset $\mathcal{D}$. An example of this, with N = 5, is
shown in Figure 2.18(a).
The corresponding cdf is given by

$$ \hat{P}_N(x) = \frac{1}{N}\sum_{n=1}^{N} \mathbb{I}(x^{(n)} \le x) = \frac{1}{N}\sum_{n=1}^{N} u_{x^{(n)}}(x) \tag{2.148} $$

where $u_y(x)$ is a step function at y defined by

$$ u_y(x) = \begin{cases} 1 & \text{if } x \ge y \\ 0 & \text{if } x < y \end{cases} \tag{2.149} $$

This can be visualized as a “staircase”, as in Figure 2.18(b), where the jumps of height 1/N occur at
every sample.
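
The empirical cdf in Equation (2.148) is straightforward to implement by sorting the samples and counting how many are at or below the query point. The following is a minimal sketch; the sample values are made up for illustration.

```python
from bisect import bisect_right

def empirical_cdf(samples):
    """Return a function x -> (1/N) * #{n : x_n <= x}, the empirical cdf of the samples."""
    ordered = sorted(samples)
    n = len(ordered)

    def cdf(x: float) -> float:
        # bisect_right counts the samples that are <= x.
        return bisect_right(ordered, x) / n

    return cdf

data = [0.3, 1.2, 1.9, 2.5, 4.0]   # made-up samples with N = 5
P_hat = empirical_cdf(data)
steps = [P_hat(x) for x in (0.0, 0.3, 1.5, 2.5, 5.0)]
```

Each sample contributes a jump of height 1/N = 0.2, so the function rises from 0 to 1 in five steps.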


2.8 Transformations of random variables * 


Suppose $\boldsymbol{x} \sim p()$ is some random variable, and $\boldsymbol{y} = f(\boldsymbol{x})$ is some deterministic transformation of it.
In this section, we discuss how to compute $p(\boldsymbol{y})$.


2.8.1 Discrete case 


If X is a discrete rv, we can derive the pmf for Y by simply summing up the probability mass for all
the x’s such that $f(x) = y$:

$$ p_y(y) = \sum_{x: f(x) = y} p_x(x) \tag{2.150} $$

For example, if $f(X) = 1$ if X is even and $f(X) = 0$ otherwise, and $p_x(X)$ is uniform on the set
$\{1, \ldots, 10\}$, then $p_y(1) = \sum_{x \in \{2,4,6,8,10\}} p_x(x) = 0.5$, and hence $p_y(0) = 0.5$ also. Note that in this
example, f is a many-to-one function.
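
Equation (2.150) translates directly into code: accumulate the mass of every x under the key f(x). The sketch below reproduces the even/odd example using exact fractions; the function name `pushforward_pmf` is ours.

```python
from collections import defaultdict
from fractions import Fraction

def pushforward_pmf(p_x: dict, f) -> dict:
    """Compute p_y(y) = sum of p_x(x) over all x with f(x) = y."""
    p_y = defaultdict(Fraction)          # missing keys start at Fraction(0)
    for x, prob in p_x.items():
        p_y[f(x)] += prob
    return dict(p_y)

p_x = {x: Fraction(1, 10) for x in range(1, 11)}   # uniform on {1, ..., 10}
p_y = pushforward_pmf(p_x, lambda x: 1 if x % 2 == 0 else 0)
```

Because f is many-to-one, several values of x contribute to each value of y, and the result is a pmf with mass 1/2 on each of 0 and 1.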


2.8.2 Continuous case 


If X is continuous, we cannot use Equation (2.150) since $p_x(x)$ is a density, not a pmf, and we cannot
sum up densities. Instead, we work with cdf’s, as follows:

$$ P_y(y) \triangleq \Pr(Y \le y) = \Pr(f(X) \le y) = \Pr(X \in \{x \mid f(x) \le y\}) \tag{2.151} $$

If f is invertible, we can derive the pdf of y by differentiating the cdf, as we show below. If f is not
invertible, we can use numerical integration, or a Monte Carlo approximation.


2.8.3 Invertible transformations (bijections)


In this section, we consider the case of monotonic and hence invertible functions. (Note a function is
invertible iff it is a bijection.) With this assumption, there is a simple formula for the pdf of y, as we
will see. (This can be generalized to invertible, but non-monotonic, functions, but we ignore this
case.)

Figure 2.19: (a) Mapping a uniform pdf through the function $f(x) = 2x + 1$. (b) Illustration of how two
nearby points, x and x + dx, get mapped under f. If $\frac{dy}{dx} > 0$, the function is locally increasing, but if $\frac{dy}{dx} < 0$,
the function is locally decreasing. (In the latter case, if $f(x) = y + dy$, then $f(x + dx) = y$, since increasing x
by dx should decrease the output by dy.) Here $x + dx > x$. From [Jan18]. Used with kind permission of Eric Jang.


2.8.3.1 Change of variables: scalar case 


We start with an example. Suppose $x \sim \text{Unif}(0, 1)$, and $y = f(x) = 2x + 1$. This function stretches
and shifts the probability distribution, as shown in Figure 2.19(a). Now let us zoom in on a point x
and another point that is infinitesimally close, namely x + dx. We see this interval gets mapped to
(y, y + dy). The probability mass in these intervals must be the same, hence $p(x)dx = p(y)dy$, and so
$p(y) = p(x)dx/dy$. However, since it does not matter (in terms of probability preservation) whether
$dx/dy > 0$ or $dx/dy < 0$, we get

$$ p_y(y) = p_x(x)\left|\frac{dx}{dy}\right| \tag{2.152} $$

Now consider the general case for any $p_x(x)$ and any monotonic function $f: \mathbb{R} \to \mathbb{R}$. Let $g = f^{-1}$,
so $y = f(x)$ and $x = g(y)$. If we assume that $f: \mathbb{R} \to \mathbb{R}$ is monotonically increasing we get

$$ P_y(y) = \Pr(f(X) \le y) = \Pr(X \le f^{-1}(y)) = P_x(f^{-1}(y)) = P_x(g(y)) \tag{2.153} $$

Taking derivatives we get

$$ p_y(y) \triangleq \frac{d}{dy}P_y(y) = \frac{d}{dy}P_x(x) = \frac{dx}{dy}\frac{d}{dx}P_x(x) = \frac{dx}{dy}p_x(x) \tag{2.154} $$

We can derive a similar expression (but with opposite signs) for the case where f is monotonically
decreasing. To handle the general case we take the absolute value to get

$$ p_y(y) = p_x(g(y))\left|\frac{d}{dy}g(y)\right| \tag{2.155} $$

This is called the change of variables formula.

Figure 2.20: Illustration of an affine transformation applied to a unit square, $f(\boldsymbol{x}) = \mathbf{A}\boldsymbol{x} + \boldsymbol{b}$. (a) Here
$\mathbf{A} = \mathbf{I}$. (b) Here $\boldsymbol{b} = \boldsymbol{0}$. From [Jan18]. Used with kind permission of Eric Jang.
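
The scalar change of variables formula, Equation (2.155), can be wrapped in a small helper and checked against simulation for the example $x \sim \text{Unif}(0, 1)$, $y = 2x + 1$. This is only a sketch: the helper name, the bin used for the Monte Carlo check, and the sample size are our own choices.

```python
import random

def change_of_variables(p_x, g, dg_dy):
    """Return p_y(y) = p_x(g(y)) * |dg/dy (y)| for an invertible map with inverse g."""
    return lambda y: p_x(g(y)) * abs(dg_dy(y))

# x ~ Unif(0, 1) and y = f(x) = 2x + 1, so g(y) = (y - 1) / 2 and dg/dy = 1/2.
p_x = lambda x: 1.0 if 0.0 <= x <= 1.0 else 0.0
p_y = change_of_variables(p_x, g=lambda y: (y - 1.0) / 2.0, dg_dy=lambda y: 0.5)

# Monte Carlo check: fraction of transformed samples in a small bin, divided by bin width.
random.seed(0)
ys = [2.0 * random.random() + 1.0 for _ in range(200_000)]
lo, hi = 1.5, 1.7
density_mc = sum(lo <= y < hi for y in ys) / len(ys) / (hi - lo)
```

The density of y is 1/2 on [1, 3] and 0 elsewhere: stretching the interval by a factor of 2 halves the height, so the total mass stays 1.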


2.8.3.2 Change of variables: multivariate case 


We can extend the previous results to multivariate distributions as follows. Let f be an invertible
function that maps $\mathbb{R}^n$ to $\mathbb{R}^n$, with inverse g. Suppose we want to compute the pdf of $\boldsymbol{y} = f(\boldsymbol{x})$. By
analogy with the scalar case, we have

$$ p_y(\boldsymbol{y}) = p_x(g(\boldsymbol{y}))\,\left|\det[\mathbf{J}_g(\boldsymbol{y})]\right| \tag{2.156} $$

where $\mathbf{J}_g = \frac{dg(\boldsymbol{y})}{d\boldsymbol{y}^\top}$ is the Jacobian of g, and $|\det \mathbf{J}(\boldsymbol{y})|$ is the absolute value of the determinant of $\mathbf{J}$
evaluated at $\boldsymbol{y}$. (See Section 7.8.5 for a discussion of Jacobians.) In Exercise 3.6 you will use this
formula to derive the normalization constant for a multivariate Gaussian.

Figure 2.20 illustrates this result in 2d, for the case where $f(\boldsymbol{x}) = \mathbf{A}\boldsymbol{x} + \boldsymbol{b}$, where $\mathbf{A} = \begin{pmatrix} a & c \\ b & d \end{pmatrix}$.
We see that the area of the unit square changes by a factor of $\det(\mathbf{A}) = ad - bc$, which is the area of
the parallelogram.

As another example, consider transforming a density from Cartesian coordinates $\boldsymbol{x} = (x_1, x_2)$ to
polar coordinates $\boldsymbol{y} = f(x_1, x_2)$, so $g(r, \theta) = (r\cos\theta, r\sin\theta)$. Then

$$ \mathbf{J}_g = \begin{pmatrix} \frac{\partial x_1}{\partial r} & \frac{\partial x_1}{\partial \theta} \\ \frac{\partial x_2}{\partial r} & \frac{\partial x_2}{\partial \theta} \end{pmatrix} = \begin{pmatrix} \cos\theta & -r\sin\theta \\ \sin\theta & r\cos\theta \end{pmatrix} \tag{2.157} $$

$$ |\det(\mathbf{J}_g)| = |r\cos^2\theta + r\sin^2\theta| = |r| \tag{2.158} $$

Hence

$$ p_{r,\theta}(r, \theta) = p_{x_1,x_2}(r\cos\theta, r\sin\theta)\, r \tag{2.159} $$

To see this geometrically, notice that the probability mass of the shaded patch in Figure 2.21 is given by

$$ \Pr(r \le R \le r + dr, \theta \le \Theta \le \theta + d\theta) = p_{r,\theta}(r, \theta)\,dr\,d\theta \tag{2.160} $$

In the limit, this is equal to the density at the center of the patch times the size of the patch, which
is given by $r\,dr\,d\theta$. Hence

$$ p_{r,\theta}(r, \theta)\,dr\,d\theta = p_{x_1,x_2}(r\cos\theta, r\sin\theta)\, r\,dr\,d\theta \tag{2.161} $$

Figure 2.21: Change of variables from polar to Cartesian. The area of the shaded patch is $r\,dr\,d\theta$. Adapted
from Figure 3.16 of [Ric95].
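
The polar example can be verified numerically. The sketch below estimates the Jacobian determinant of g by central finite differences (it should be close to r) and checks that the transformed density integrates to 1. The standard 2d Gaussian is used as an example choice of $p_{x_1,x_2}$; it is our assumption, not something fixed by the text.

```python
import math

def g(r: float, theta: float) -> tuple[float, float]:
    """Inverse map from polar to Cartesian coordinates."""
    return r * math.cos(theta), r * math.sin(theta)

def jacobian_det(r: float, theta: float, eps: float = 1e-6) -> float:
    """Central finite-difference estimate of det J_g at (r, theta)."""
    dx_dr = [(p - m) / (2 * eps) for p, m in zip(g(r + eps, theta), g(r - eps, theta))]
    dx_dt = [(p - m) / (2 * eps) for p, m in zip(g(r, theta + eps), g(r, theta - eps))]
    return dx_dr[0] * dx_dt[1] - dx_dt[0] * dx_dr[1]

def p_cartesian(x1: float, x2: float) -> float:
    """Standard 2d Gaussian density, used here only as an example p_{x1,x2}."""
    return math.exp(-0.5 * (x1 * x1 + x2 * x2)) / (2 * math.pi)

def p_polar(r: float, theta: float) -> float:
    """Density over (r, theta) from Equation (2.159)."""
    return p_cartesian(*g(r, theta)) * r

det_at_point = jacobian_det(2.0, 0.3)

# For this example the polar density does not depend on theta, so the theta
# integral contributes a factor of 2 pi; integrate over r with a midpoint rule.
n, r_max = 10_000, 10.0
h = r_max / n
total_mass = 2 * math.pi * sum(p_polar((i + 0.5) * h, 0.0) for i in range(n)) * h
```

Without the factor r from the Jacobian, the polar density would not integrate to 1.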


2.8.4 Moments of a linear transformation 


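Suppose f is an affine function, so $\boldsymbol{y} = \mathbf{A}\boldsymbol{x} + \boldsymbol{b}$. In this case, we can easily derive the mean and
covariance of $\boldsymbol{y}$ as follows. First, for the mean, we have

$$ \mathbb{E}[\boldsymbol{y}] = \mathbb{E}[\mathbf{A}\boldsymbol{x} + \boldsymbol{b}] = \mathbf{A}\boldsymbol{\mu} + \boldsymbol{b} \tag{2.162} $$

where $\boldsymbol{\mu} = \mathbb{E}[\boldsymbol{x}]$. If f is a scalar-valued function, $f(\boldsymbol{x}) = \boldsymbol{a}^\top\boldsymbol{x} + b$, the corresponding result is

$$ \mathbb{E}[\boldsymbol{a}^\top\boldsymbol{x} + b] = \boldsymbol{a}^\top\boldsymbol{\mu} + b \tag{2.163} $$

For the covariance, we have

$$ \text{Cov}[\boldsymbol{y}] = \text{Cov}[\mathbf{A}\boldsymbol{x} + \boldsymbol{b}] = \mathbf{A}\boldsymbol{\Sigma}\mathbf{A}^\top \tag{2.164} $$

where $\boldsymbol{\Sigma} = \text{Cov}[\boldsymbol{x}]$. We leave the proof of this as an exercise.
As a special case, if $y = \boldsymbol{a}^\top\boldsymbol{x} + b$, we get

$$ \mathbb{V}[y] = \mathbb{V}[\boldsymbol{a}^\top\boldsymbol{x} + b] = \boldsymbol{a}^\top\boldsymbol{\Sigma}\boldsymbol{a} \tag{2.165} $$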


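For example, to compute the variance of the sum of two scalar random variables, we can set $\boldsymbol{a} = [1, 1]$
to get

$$ \mathbb{V}[x_1 + x_2] = \begin{pmatrix} 1 & 1 \end{pmatrix}\begin{pmatrix} \Sigma_{11} & \Sigma_{12} \\ \Sigma_{21} & \Sigma_{22} \end{pmatrix}\begin{pmatrix} 1 \\ 1 \end{pmatrix} \tag{2.166} $$

$$ = \Sigma_{11} + \Sigma_{22} + 2\Sigma_{12} = \mathbb{V}[x_1] + \mathbb{V}[x_2] + 2\,\text{Cov}[x_1, x_2] \tag{2.167} $$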
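
The quadratic form in Equation (2.165) is simple to evaluate directly. The sketch below uses plain Python lists; the covariance matrix is made up for illustration, and the function name is ours.

```python
def quadratic_form(a, Sigma):
    """Return a^T Sigma a for a vector a and a square matrix Sigma (lists of floats)."""
    return sum(a[i] * Sigma[i][j] * a[j] for i in range(len(a)) for j in range(len(a)))

# Illustrative covariance matrix (symmetric, positive definite); values are made up.
Sigma = [[2.0, 0.6],
         [0.6, 1.0]]

var_sum = quadratic_form([1.0, 1.0], Sigma)     # V[x1 + x2] = 2.0 + 1.0 + 2 * 0.6
var_diff = quadratic_form([1.0, -1.0], Sigma)   # V[x1 - x2] = 2.0 + 1.0 - 2 * 0.6
```

Choosing $\boldsymbol{a} = [1, -1]$ gives the variance of a difference, where the covariance term enters with a negative sign.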
Note, however, that although some distributions (such as the Gaussian) are completely characterized
by their mean and covariance, in general we must use the techniques described above to derive the
full distribution of $\boldsymbol{y}$.

 -  -  1  2  3  4  -  -
 7  6  5  -  -  -  -  -   z_0 = x_0 y_0 = 5
 -  7  6  5  -  -  -  -   z_1 = x_0 y_1 + x_1 y_0 = 16
 -  -  7  6  5  -  -  -   z_2 = x_0 y_2 + x_1 y_1 + x_2 y_0 = 34
 -  -  -  7  6  5  -  -   z_3 = x_1 y_2 + x_2 y_1 + x_3 y_0 = 52
 -  -  -  -  7  6  5  -   z_4 = x_2 y_2 + x_3 y_1 = 45
 -  -  -  -  -  7  6  5   z_5 = x_3 y_2 = 28

Table 2.4: Discrete convolution of $\boldsymbol{x} = [1, 2, 3, 4]$ with $\boldsymbol{y} = [5, 6, 7]$ to yield $\boldsymbol{z} = [5, 16, 34, 52, 45, 28]$. In general,
$z_n = \sum_{k=-\infty}^{\infty} x_k y_{n-k}$. We see that this operation consists of “flipping” $\boldsymbol{y}$ and then “dragging” it over $\boldsymbol{x}$,
multiplying elementwise, and adding up the results.


2.8.5 The convolution theorem 


Let $y = x_1 + x_2$, where $x_1$ and $x_2$ are independent rv’s. If these are discrete random variables, we
can compute the pmf for the sum as follows:

$$ p(y = j) = \sum_k p(x_1 = k)\,p(x_2 = j - k) \tag{2.168} $$

for $j = \ldots, -2, -1, 0, 1, 2, \ldots$.
If $x_1$ and $x_2$ have pdf’s $p_1(x_1)$ and $p_2(x_2)$, what is the distribution of y? The cdf for y is given by

$$ P_y(y^*) = \Pr(y \le y^*) = \int_{-\infty}^{\infty} p_1(x_1)\left[\int_{-\infty}^{y^* - x_1} p_2(x_2)\,dx_2\right]dx_1 \tag{2.169} $$

where we integrate over the region R defined by $x_1 + x_2 < y^*$. Thus the pdf for y is

$$ p(y) = \left[\frac{d}{dy^*}P_y(y^*)\right]_{y^* = y} = \int p_1(x_1)\,p_2(y - x_1)\,dx_1 \tag{2.170} $$

where we used the rule of differentiating under the integral sign:

$$ \frac{d}{dx}\int_{a(x)}^{b(x)} f(t)\,dt = f(b(x))\frac{db(x)}{dx} - f(a(x))\frac{da(x)}{dx} \tag{2.171} $$

We can write Equation (2.170) as follows:

$$ p = p_1 \circledast p_2 \tag{2.172} $$

where $\circledast$ represents the convolution operator. For finite length vectors, the integrals become
sums, and convolution can be thought of as a “flip and drag” operation, as illustrated in Table 2.4.
Consequently, Equation (2.170) is called the convolution theorem.

For example, suppose we roll two dice, so $p_1$ and $p_2$ are both the discrete uniform distributions
over $\{1, 2, \ldots, 6\}$. Let $y = x_1 + x_2$ be the sum of the dice. We have

$$ p(y = 2) = p(x_1 = 1)\,p(x_2 = 1) = \frac{1}{6}\cdot\frac{1}{6} = \frac{1}{36} \tag{2.173} $$

$$ p(y = 3) = p(x_1 = 1)\,p(x_2 = 2) + p(x_1 = 2)\,p(x_2 = 1) = \frac{1}{6}\cdot\frac{1}{6} + \frac{1}{6}\cdot\frac{1}{6} = \frac{2}{36} \tag{2.174} $$

$$ \cdots \tag{2.175} $$

Continuing in this way, we find $p(y = 4) = 3/36$, $p(y = 5) = 4/36$, $p(y = 6) = 5/36$, $p(y = 7) = 6/36$,
$p(y = 8) = 5/36$, $p(y = 9) = 4/36$, $p(y = 10) = 3/36$, $p(y = 11) = 2/36$ and $p(y = 12) = 1/36$. See
Figure 2.22 for a plot. We see that the distribution looks like a Gaussian; we explain the reasons for
this in Section 2.8.6.

Figure 2.22: Distribution of the sum of two dice rolls, i.e., $p(y)$ where $y = x_1 + x_2$ and $x_i \sim \text{Unif}(\{1, 2, \ldots, 6\})$.
From https://en.wikipedia.org/wiki/Probability_distribution. Used with kind permission of
Wikipedia author Tim Stellmach.
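
Both the “flip and drag” example of Table 2.4 and the dice example can be reproduced with a few lines of code. The following sketch implements the finite discrete convolution directly; the function name is ours, and the dice are represented by integer counts out of 36 so that the arithmetic stays exact.

```python
def convolve(x, y):
    """Full discrete convolution z_n = sum_k x_k * y_{n-k} of two finite sequences."""
    z = [0] * (len(x) + len(y) - 1)
    for i, xi in enumerate(x):
        for j, yj in enumerate(y):
            z[i + j] += xi * yj
    return z

# Table 2.4 example.
z = convolve([1, 2, 3, 4], [5, 6, 7])

# Two fair dice: work in counts out of 36 to keep the arithmetic exact.
die = [1] * 6                       # outcomes 1..6, each with weight 1 (probability 1/6)
counts = convolve(die, die)         # index 0 corresponds to a sum of 2
pmf_sum = {total: c for total, c in zip(range(2, 13), counts)}
```

Dividing each entry of `pmf_sum` by 36 gives the probabilities listed above, with the peak of 6/36 at a sum of 7.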

We can also compute the pdf of the sum of two continuous rv’s. For example, in the case of
Gaussians, where $x_1 \sim \mathcal{N}(\mu_1, \sigma_1^2)$ and $x_2 \sim \mathcal{N}(\mu_2, \sigma_2^2)$, one can show (Exercise 2.4) that if $y = x_1 + x_2$
then

$$ p(y) = \mathcal{N}(x_1|\mu_1, \sigma_1^2) \circledast \mathcal{N}(x_2|\mu_2, \sigma_2^2) = \mathcal{N}(y|\mu_1 + \mu_2, \sigma_1^2 + \sigma_2^2) \tag{2.176} $$

Hence the convolution of two Gaussians is a Gaussian.
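
This result can be checked numerically by evaluating the convolution integral of Equation (2.170) on a grid and comparing it with the closed-form Gaussian. The sketch below uses `statistics.NormalDist`, whose `+` operator models the sum of two independent Gaussians; the parameter values and the integration grid are illustrative.

```python
from statistics import NormalDist

p1 = NormalDist(mu=1.0, sigma=0.5)     # illustrative parameters
p2 = NormalDist(mu=-2.0, sigma=1.2)

def convolved_pdf(y: float, n: int = 4_000) -> float:
    """Midpoint-rule estimate of the integral of p1(x1) * p2(y - x1) over x1."""
    lo, hi = p1.mean - 10 * p1.stdev, p1.mean + 10 * p1.stdev
    h = (hi - lo) / n
    total = 0.0
    for i in range(n):
        x1 = lo + (i + 0.5) * h
        total += p1.pdf(x1) * p2.pdf(y - x1)
    return total * h

closed_form = p1 + p2    # NormalDist addition models the sum of independent Gaussians
max_gap = max(abs(convolved_pdf(y) - closed_form.pdf(y)) for y in (-4.0, -1.0, 0.0, 2.0))
```

The means add and the variances add, so the standard deviation of the sum is $\sqrt{0.5^2 + 1.2^2} = 1.3$ here.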


2.8.6 Central limit theorem 


Now consider N random variables with pdf’s (not necessarily Gaussian) $p(x)$, each with mean
$\mu$ and variance $\sigma^2$. We assume each variable is independent and identically distributed or
iid for short, which means $X_n \sim p(X)$ are independent samples from the same distribution. Let
$S_N = \sum_{n=1}^{N} X_n$ be the sum of the rv’s. One can show that, as N increases, the distribution of this
sum approaches

$$ p(S_N = u) = \frac{1}{\sqrt{2\pi N\sigma^2}}\exp\left(-\frac{(u - N\mu)^2}{2N\sigma^2}\right) \tag{2.177} $$




Figure 2.23: The central limit theorem in pictures. We plot a histogram of $\hat{\mu}_N^s = \frac{1}{N}\sum_{n=1}^{N} x_{ns}$, where
$x_{ns} \sim \text{Beta}(1, 5)$, for $s = 1 : 10000$. As $N \to \infty$, the distribution tends towards a Gaussian. (a) N = 1. (b)
N = 5. Adapted from Figure 2.6 of [Bis06]. Generated by centralLimitDemo.ipynb.

Hence the distribution of the quantity

$$ Z_N \triangleq \frac{S_N - N\mu}{\sigma\sqrt{N}} = \frac{\bar{X} - \mu}{\sigma/\sqrt{N}} \tag{2.178} $$

converges to the standard normal, where $\bar{X} = S_N/N$ is the sample mean. This is called the central
limit theorem. See e.g., [Jay03, p222] or [Ric95, p169] for a proof.

In Figure 2.23 we give an example in which we compute the sample mean of rv’s drawn from a
beta distribution. We see that the sampling distribution of this mean rapidly converges to a Gaussian
distribution.
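
The beta example can be imitated with a short simulation. The sketch below standardizes sample means of Beta(1, 5) draws as in Equation (2.178) and measures how often the result falls inside the central 95% interval of a standard normal. The choices N = 30 and S = 5000 are ours (they differ from the figure), and the beta mean and variance come from Equation (2.139).

```python
import random
from statistics import NormalDist

random.seed(0)
a, b = 1.0, 5.0
mu = a / (a + b)                                        # Beta mean, Equation (2.139)
sigma = (a * b / ((a + b) ** 2 * (a + b + 1))) ** 0.5   # Beta standard deviation

def standardized_mean(N: int) -> float:
    """Draw N Beta(1, 5) samples and return Z_N = (mean - mu) / (sigma / sqrt(N))."""
    xbar = sum(random.betavariate(a, b) for _ in range(N)) / N
    return (xbar - mu) / (sigma / N ** 0.5)

N, S = 30, 5_000
zs = [standardized_mean(N) for _ in range(S)]

# If Z_N is approximately standard normal, about 95% of values fall within +/- 1.96.
z95 = NormalDist().inv_cdf(0.975)
coverage = sum(abs(z) <= z95 for z in zs) / S
```

Even though Beta(1, 5) is strongly skewed, the coverage is already close to 0.95 for moderate N.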


2.8.7 Monte Carlo approximation 


Suppose $\boldsymbol{x}$ is a random variable, and $\boldsymbol{y} = f(\boldsymbol{x})$ is some function of $\boldsymbol{x}$. It is often difficult to compute the
induced distribution $p(\boldsymbol{y})$ analytically. One simple but powerful alternative is to draw a large number
of samples from the distribution of $\boldsymbol{x}$, and then to use these samples (instead of the distribution) to
approximate $p(\boldsymbol{y})$.

For example, suppose $x \sim \text{Unif}(-1, 1)$ and $y = f(x) = x^2$. We can approximate $p(y)$ by drawing
many samples from $p(x)$ (using a uniform random number generator), squaring them, and
computing the resulting empirical distribution, which is given by

$$ p_S(y) \triangleq \frac{1}{S}\sum_{s=1}^{S} \delta(y - y_s) \tag{2.179} $$

This is just an equally weighted “sum of spikes”, each centered on one of the samples (see Section 2.7.6).
By using enough samples, we can approximate $p(y)$ rather well. See Figure 2.24 for an illustration.

This approach is called a Monte Carlo approximation to the distribution. (The term “Monte
Carlo” comes from the name of a famous gambling casino in Monaco.) Monte Carlo techniques were
first developed in the area of statistical physics (in particular, during development of the atomic
bomb), but are now widely used in statistics and machine learning as well. More details can be
found in the sequel to this book, [Mur23], as well as specialized books on the topic, such as [Liu01;
RC04; KTB11; BZ20].

Figure 2.24: Computing the distribution of $y = x^2$, where $p(x)$ is uniform (left). The analytic result
is shown in the middle, and the Monte Carlo approximation is shown on the right. Generated by
change_of_vars_demo1d.ipynb.
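
The $y = x^2$ example is easy to simulate. As a sketch, the code below compares the empirical cdf of the squared samples with the exact cdf, which for this example is $\Pr(x^2 \le y) = \sqrt{y}$ on [0, 1] (a short calculation that is not spelled out in the text). The sample size and evaluation points are arbitrary.

```python
import random
from bisect import bisect_right

random.seed(0)
S = 100_000
ys = sorted(random.uniform(-1.0, 1.0) ** 2 for _ in range(S))

def cdf_mc(y: float) -> float:
    """Empirical cdf of the squared samples."""
    return bisect_right(ys, y) / S

def cdf_exact(y: float) -> float:
    """Pr(x^2 <= y) = Pr(|x| <= sqrt(y)) = sqrt(y) for x ~ Unif(-1, 1) and 0 <= y <= 1."""
    return y ** 0.5

gaps = [abs(cdf_mc(y) - cdf_exact(y)) for y in (0.01, 0.25, 0.5, 0.81)]
```

The gaps shrink roughly like $1/\sqrt{S}$, so quadrupling the number of samples halves the typical error.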


2.9 Exercises 


Exercise 2.1 [Conditional independence *] 
(Source: Koller.) 


a. Let $H \in \{1, \ldots, K\}$ be a discrete random variable, and let $e_1$ and $e_2$ be the observed values of two other
random variables $E_1$ and $E_2$. Suppose we wish to calculate the vector

$$ P(H|e_1, e_2) = (P(H = 1|e_1, e_2), \ldots, P(H = K|e_1, e_2)) $$

Which of the following sets of numbers are sufficient for the calculation?
i. $P(e_1, e_2)$, $P(H)$, $P(e_1|H)$, $P(e_2|H)$
ii. $P(e_1, e_2)$, $P(H)$, $P(e_1, e_2|H)$
iii. $P(e_1|H)$, $P(e_2|H)$, $P(H)$


b. Now suppose we assume $E_1 \perp E_2 | H$ (i.e., $E_1$ and $E_2$ are conditionally independent given H). Which
of the above 3 sets are sufficient now?


Show your calculations as well as giving the final result. Hint: use Bayes rule. 


Exercise 2.2 [Pairwise independence does not imply mutual independence] 


We say that two random variables are pairwise independent if

$$ p(X_2|X_1) = p(X_2) \tag{2.180} $$

and hence

$$ p(X_2, X_1) = p(X_1)\,p(X_2|X_1) = p(X_1)\,p(X_2) \tag{2.181} $$

We say that n random variables are mutually independent if

$$ p(X_i|X_S) = p(X_i) \quad \forall S \subseteq \{1, \ldots, n\} \setminus \{i\} \tag{2.182} $$

and hence

$$ p(X_{1:n}) = \prod_{i=1}^{n} p(X_i) \tag{2.183} $$

Show that pairwise independence between all pairs of variables does not necessarily imply mutual independence.
It suffices to give a counterexample.


Exercise 2.3 [Conditional independence iff joint factorizes *] 
In the text we said $X \perp Y | Z$ iff

$$ p(x, y|z) = p(x|z)\,p(y|z) \tag{2.184} $$

for all x, y, z such that $p(z) > 0$. Now prove the following alternative definition: $X \perp Y | Z$ iff there exist
functions g and h such that

$$ p(x, y|z) = g(x, z)\,h(y, z) \tag{2.185} $$

for all x, y, z such that $p(z) > 0$.


Exercise 2.4 [Convolution of two Gaussians is a Gaussian] 


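Show that the convolution of two Gaussians is a Gaussian, i.e.,

$$ p(y) = \mathcal{N}(x_1|\mu_1, \sigma_1^2) \circledast \mathcal{N}(x_2|\mu_2, \sigma_2^2) = \mathcal{N}(y|\mu_1 + \mu_2, \sigma_1^2 + \sigma_2^2) \tag{2.186} $$

where $y = x_1 + x_2$, $x_1 \sim \mathcal{N}(\mu_1, \sigma_1^2)$ and $x_2 \sim \mathcal{N}(\mu_2, \sigma_2^2)$ are independent rv’s.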


Exercise 2.5 [Expected value of the minimum of two rv’s *] 


Suppose X,Y are two points sampled independently and uniformly at random from the interval [0,1]. What 
is the expected location of the leftmost point? 


Exercise 2.6 [Variance of a sum] 


Show that the variance of a sum is

$$ \mathbb{V}[X + Y] = \mathbb{V}[X] + \mathbb{V}[Y] + 2\,\text{Cov}[X, Y], \tag{2.187} $$

where $\text{Cov}[X, Y]$ is the covariance between X and Y.


Exercise 2.7 [Deriving the inverse gamma density *] 
Let X ~ Ga(a,b), and Y = 1/X. Derive the distribution of Y. 


Exercise 2.8 [Mean, mode, variance for the beta distribution] 


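Suppose $\theta \sim \text{Beta}(a, b)$. Show that the mean, mode and variance are given by

$$ \mathbb{E}[\theta] = \frac{a}{a + b} \tag{2.188} $$

$$ \mathbb{V}[\theta] = \frac{ab}{(a + b)^2(a + b + 1)} \tag{2.189} $$

$$ \text{mode}[\theta] = \frac{a - 1}{a + b - 2} \tag{2.190} $$


Exercise 2.9 [Bayes rule for medical diagnosis *]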


After your yearly checkup, the doctor has bad news and good news. The bad news is that you tested positive 
for a serious disease, and that the test is 99% accurate (i.e., the probability of testing positive given that you 
have the disease is 0.99, as is the probability of testing negative given that you don’t have the disease). The 
good news is that this is a rare disease, striking only one in 10,000 people. What are the chances that you 
actually have the disease? (Show your calculations as well as giving the final result.) 


Exercise 2.10 [Legal reasoning] 


(Source: Peter Lee.) Suppose a crime has been committed. Blood is found at the scene for which there is no 
innocent explanation. It is of a type which is present in 1% of the population. 


a. The prosecutor claims: “There is a 1% chance that the defendant would have the crime blood type if he 
were innocent. Thus there is a 99% chance that he is guilty”. This is known as the prosecutor’s fallacy. 
What is wrong with this argument? 


b. The defender claims: “The crime occurred in a city of 800,000 people. The blood type would be found in 
approximately 8000 people. The evidence has provided a probability of just 1 in 8000 that the defendant 
is guilty, and thus has no relevance.” This is known as the defender’s fallacy. What is wrong with this 
argument? 


Exercise 2.11 [Probabilities are sensitive to the form of the question that was used to generate the answer *]


(Source: Minka.) My neighbor has two children. Assuming that the gender of a child is like a coin flip, 
it is most likely, a priori, that my neighbor has one boy and one girl, with probability 1/2. The other 
possibilities—two boys or two girls—have probabilities 1/4 and 1/4. 


a. Suppose I ask him whether he has any boys, and he says yes. What is the probability that one child is a 
girl? 


b. Suppose instead that I happen to see one of his children run by, and it is a boy. What is the probability 
that the other child is a girl? 


Exercise 2.12 [Normalization constant for a 1D Gaussian] 


The normalization constant for a zero-mean Gaussian is given by

$$ Z = \int_a^b \exp\left(-\frac{x^2}{2\sigma^2}\right)dx \tag{2.191} $$

where $a = -\infty$ and $b = \infty$. To compute this, consider its square

$$ Z^2 = \int_a^b\int_a^b \exp\left(-\frac{x^2 + y^2}{2\sigma^2}\right)dx\,dy \tag{2.192} $$

Let us change variables from cartesian (x, y) to polar $(r, \theta)$ using $x = r\cos\theta$ and $y = r\sin\theta$. Since
$dx\,dy = r\,dr\,d\theta$, and $\cos^2\theta + \sin^2\theta = 1$, we have

$$ Z^2 = \int_0^{2\pi}\int_0^{\infty} r\exp\left(-\frac{r^2}{2\sigma^2}\right)dr\,d\theta \tag{2.193} $$

Evaluate this integral and hence show $Z = \sqrt{\sigma^2 2\pi}$. Hint 1: separate the integral into a product of two terms,
the first of which (involving $d\theta$) is constant, so is easy. Hint 2: if $u = e^{-r^2/2\sigma^2}$ then $du/dr = -\frac{1}{\sigma^2} r e^{-r^2/2\sigma^2}$,
so the second integral is also easy (since $\int u'(r)\,dr = u(r)$).




3 Probability: Multivariate Models 


3.1 Joint distributions for multiple random variables 
In this section, we discuss various ways to measure the dependence of one or more variables on each
other.


3.1.1 Covariance 


The covariance between two rv’s X and Y measures the degree to which X and Y are (linearly) 
related. Covariance is defined as 


$$
\mathrm{Cov}[X, Y] \triangleq \mathbb{E}\left[(X - \mathbb{E}[X])(Y - \mathbb{E}[Y])\right] = \mathbb{E}[XY] - \mathbb{E}[X]\,\mathbb{E}[Y] \tag{3.1}
$$


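If $\mathbf{x}$ is a $D$-dimensional random vector, its covariance matrix is defined to be the following
symmetric, positive semi-definite matrix:

$$
\mathrm{Cov}[\mathbf{x}] \triangleq \mathbb{E}\left[(\mathbf{x} - \mathbb{E}[\mathbf{x}])(\mathbf{x} - \mathbb{E}[\mathbf{x}])^\top\right] \triangleq \boldsymbol{\Sigma} \tag{3.2}
$$

$$
= \begin{pmatrix}
\mathbb{V}[X_1] & \mathrm{Cov}[X_1, X_2] & \cdots & \mathrm{Cov}[X_1, X_D] \\
\mathrm{Cov}[X_2, X_1] & \mathbb{V}[X_2] & \cdots & \mathrm{Cov}[X_2, X_D] \\
\vdots & \vdots & \ddots & \vdots \\
\mathrm{Cov}[X_D, X_1] & \mathrm{Cov}[X_D, X_2] & \cdots & \mathbb{V}[X_D]
\end{pmatrix} \tag{3.3}
$$

from which we get the important result

$$
\mathbb{E}\left[\mathbf{x}\mathbf{x}^\top\right] = \boldsymbol{\Sigma} + \boldsymbol{\mu}\boldsymbol{\mu}^\top \tag{3.4}
$$

Another useful result is that the covariance of a linear transformation is given by

$$
\mathrm{Cov}[\mathbf{A}\mathbf{x} + \mathbf{b}] = \mathbf{A}\,\mathrm{Cov}[\mathbf{x}]\,\mathbf{A}^\top \tag{3.5}
$$

as shown in Exercise 3.4.
The cross-covariance between two random vectors is defined as

$$
\mathrm{Cov}[\mathbf{x}, \mathbf{y}] = \mathbb{E}\left[(\mathbf{x} - \mathbb{E}[\mathbf{x}])(\mathbf{y} - \mathbb{E}[\mathbf{y}])^\top\right] \tag{3.6}
$$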
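Equations (3.4) and (3.5) are easy to check numerically. The following is a minimal sketch, assuming NumPy is available; the dimensions, the matrix `A`, the offset `b` and the sampling distribution are all made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: 3-dimensional x, mapped to 2 dimensions by y = A x + b.
D, N = 3, 10_000
mix = rng.normal(size=(D, D))
x = rng.normal(size=(N, D)) @ mix.T + np.array([1.0, 0.0, -1.0])
A = rng.normal(size=(2, D))
b = np.array([1.0, -2.0])


def sample_cov(samples):
    """Empirical version of E[(x - E[x])(x - E[x])^T]; rows are samples."""
    centered = samples - samples.mean(axis=0)
    return centered.T @ centered / len(samples)


cov_x = sample_cov(x)
cov_y = sample_cov(x @ A.T + b)

# Equation (3.5): the offset b drops out, leaving A Cov[x] A^T.
max_gap = np.abs(cov_y - A @ cov_x @ A.T).max()

# Equation (3.4): E[x x^T] = Sigma + mu mu^T.
mu = x.mean(axis=0)
second_moment = x.T @ x / N
moment_gap = np.abs(second_moment - (cov_x + np.outer(mu, mu))).max()
```

Both identities hold exactly for the empirical moments (up to floating point error), since centering and the map $\mathbf{x} \mapsto \mathbf{A}\mathbf{x} + \mathbf{b}$ are both linear.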
Figure 3.1: Several sets of (x, y) points, with the correlation coefficient of x and y for each set. Note that
the correlation reflects the noisiness and direction of a linear relationship (top row), but not the slope of
that relationship (middle), nor many aspects of nonlinear relationships (bottom). (Note: the figure in the
center has a slope of 0 but in that case the correlation coefficient is undefined because the variance of Y
is zero.) From https://en.wikipedia.org/wiki/Pearson_correlation_coefficient. Used with kind
permission of Wikipedia author Imagecreator.


3.1.2 Correlation 


Covariances can be between negative and positive infinity. Sometimes it is more convenient to 
work with a normalized measure, with a finite lower and upper bound. The (Pearson) correlation 
coefficient between X and Y is defined as 


$$
\rho \triangleq \mathrm{corr}[X, Y] \triangleq \frac{\mathrm{Cov}[X, Y]}{\sqrt{\mathbb{V}[X]\,\mathbb{V}[Y]}} \tag{3.7}
$$

One can show (Exercise 3.2) that $-1 \le \rho \le 1$.

One can also show that corr [X,Y] = 1 if and only if Y = aX +b (and a > 0) for some parameters 
a and b, i.e., if there is a linear relationship between X and Y (see Exercise 3.3). Intuitively one 
might expect the correlation coefficient to be related to the slope of the regression line, i.e., the 
coefficient a in the expression Y = aX + b. However, as we show in Equation (11.27), the regression 
coefficient is in fact given by a = Cov [X,Y] /V [X]. In Figure 3.1, we show that the correlation 
coefficient can be 0 for strong, but nonlinear, relationships. (Compare to Figure 6.6.) Thus a better 
way to think of the correlation coefficient is as a degree of linearity. (See correlation2d.ipynb for a 
demo to illustrate this idea.) 

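Figure 3.2: Examples of spurious correlation between causally unrelated time series. Consumption of ice
cream (red) and violent crime rate (yellow) over time. From http://icbseverywhere.com/blog/2014/10/the-logic-of-causal-conclusions/.
Used with kind permission of Barbara Drescher.

In the case of a vector $\mathbf{x}$ of related random variables, the correlation matrix is given by

$$
\mathrm{corr}(\mathbf{x}) = \begin{pmatrix}
1 & \frac{\mathbb{E}[(X_1 - \mu_1)(X_2 - \mu_2)]}{\sigma_1 \sigma_2} & \cdots & \frac{\mathbb{E}[(X_1 - \mu_1)(X_D - \mu_D)]}{\sigma_1 \sigma_D} \\
\frac{\mathbb{E}[(X_2 - \mu_2)(X_1 - \mu_1)]}{\sigma_2 \sigma_1} & 1 & \cdots & \frac{\mathbb{E}[(X_2 - \mu_2)(X_D - \mu_D)]}{\sigma_2 \sigma_D} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\mathbb{E}[(X_D - \mu_D)(X_1 - \mu_1)]}{\sigma_D \sigma_1} & \frac{\mathbb{E}[(X_D - \mu_D)(X_2 - \mu_2)]}{\sigma_D \sigma_2} & \cdots & 1
\end{pmatrix} \tag{3.8}
$$

This can be written more compactly as

$$
\mathrm{corr}(\mathbf{x}) = (\mathrm{diag}(\mathbf{K}_{xx}))^{-\frac{1}{2}}\, \mathbf{K}_{xx}\, (\mathrm{diag}(\mathbf{K}_{xx}))^{-\frac{1}{2}} \tag{3.9}
$$

where $\mathbf{K}_{xx}$ is the auto-covariance matrix

$$
\mathbf{K}_{xx} = \boldsymbol{\Sigma} = \mathbb{E}\left[(\mathbf{x} - \mathbb{E}[\mathbf{x}])(\mathbf{x} - \mathbb{E}[\mathbf{x}])^\top\right] = \mathbf{R}_{xx} - \boldsymbol{\mu}\boldsymbol{\mu}^\top \tag{3.10}
$$

and $\mathbf{R}_{xx} = \mathbb{E}[\mathbf{x}\mathbf{x}^\top]$ is the autocorrelation matrix.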
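As a small illustration of Equation (3.9), the sketch below converts a covariance matrix into a correlation matrix by rescaling its rows and columns. It assumes NumPy; the function name `corr_from_cov` and the matrix `K` are hypothetical and chosen only for this example.

```python
import numpy as np


def corr_from_cov(K):
    """Equation (3.9): scale rows and columns of K by 1/sqrt of its diagonal."""
    inv_std = 1.0 / np.sqrt(np.diag(K))
    return inv_std[:, None] * K * inv_std[None, :]


# Made-up (valid, positive definite) covariance matrix.
K = np.array([[4.0, 2.0, 0.0],
              [2.0, 9.0, -1.5],
              [0.0, -1.5, 1.0]])
corr = corr_from_cov(K)
```

Multiplying elementwise by the outer product of the inverse standard deviations is equivalent to the two diagonal matrix products in Equation (3.9), but avoids forming the diagonal matrices explicitly.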


3.1.3 Uncorrelated does not imply independent 


If X and Y are independent, meaning p(X,Y) = p(X)p(Y), then Cov [X,Y] = 0, and hence 
corr [X, Y] = 0. So independent implies uncorrelated. However, the converse is not true: uncorrelated 
does not imply independent. For example, let $X \sim U(-1, 1)$ and $Y = X^2$. Clearly Y is dependent on
X (in fact, Y is uniquely determined by X), yet one can show (Exercise 3.1) that corr [X,Y] = 0. 
Some striking examples of this fact are shown in Figure 3.1. This shows several data sets where 
there is clear dependence between X and Y, and yet the correlation coefficient is 0. A more general 
measure of dependence between random variables is mutual information, discussed in Section 6.3. 
This is zero only if the variables truly are independent. 
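The example with $X \sim U(-1, 1)$ and $Y = X^2$ can be simulated directly. The following sketch uses only the Python standard library; the sample size and seed are arbitrary choices, and the helper `pearson` is written out by hand for illustration.

```python
import random

random.seed(0)
n = 100_000
xs = [random.uniform(-1.0, 1.0) for _ in range(n)]
ys = [x * x for x in xs]  # Y is a deterministic function of X


def pearson(a, b):
    """Empirical Pearson correlation coefficient, following Equation (3.7)."""
    mean_a, mean_b = sum(a) / len(a), sum(b) / len(b)
    cov = sum((u - mean_a) * (v - mean_b) for u, v in zip(a, b)) / len(a)
    var_a = sum((u - mean_a) ** 2 for u in a) / len(a)
    var_b = sum((v - mean_b) ** 2 for v in b) / len(b)
    return cov / (var_a * var_b) ** 0.5


rho_xy = pearson(xs, ys)                           # close to 0
rho_absx_y = pearson([abs(x) for x in xs], ys)     # close to 1
```

The correlation between X and Y is near zero even though Y is fully determined by X, whereas the correlation between |X| and Y is large, since that relationship is monotonic and close to linear.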


3.1.4 Correlation does not imply causation 


It is well known that “correlation does not imply causation”. For example, consider Figure 3.2.
In red, we plot $x_{1:T}$, where $x_t$ is the amount of ice cream sold in month $t$. In yellow, we plot $y_{1:T}$,
where $y_t$ is the violent crime rate in month $t$. (Quantities have been rescaled to make the plots
overlap.) We see a strong correlation between these signals. Indeed, it is sometimes claimed that
“eating ice cream causes murder” [Pet13]. Of course, this is just a spurious correlation, due to a
hidden common cause, namely the weather. Hot weather increases ice cream sales, for obvious
reasons. Hot weather also increases violent crime; the reason for this is hotly (ahem) debated; some
claim it is due to an increase in anger [And01], but others claim it is merely due to more people being
outside [Ash18], where most murders occur.

Figure 3.3: Illustration of Simpson’s paradox on the Iris dataset. (Left) Overall, y (sepal width) decreases
with x (sepal length). (Right) Within each group, y increases with x. Generated by simpsons_paradox.ipynb.

Another famous example concerns the positive correlation between birth rates and the presence of 
storks (a kind of bird). This has given rise to the urban legend that storks deliver babies [Mat00]. 
Of course, the true reason for the correlation is more likely due to hidden factors, such as increased 
living standards and hence more food. Many more amusing examples of such spurious correlations 
can be found in [Vig15]. 

These examples serve as a “warning sign”, that we should not treat the ability for x to predict y as 
an indicator that x causes y. 


3.1.5 Simpson’s paradox 


Simpson’s paradox says that a statistical trend or relationship that appears in several different 
groups of data can disappear or reverse sign when these groups are combined. This results in 
counterintuitive behavior if we misinterpret claims of statistical dependence in a causal way. 

A visualization of the paradox is given in Figure 3.3. Overall, we see that y decreases with x, but 
within each subpopulation, y increases with x. 

For a recent real-world example of Simpson’s paradox in the context of COVID-19, consider 
Figure 3.4(a). This shows that the case fatality rate (CFR) of COVID-19 in Italy is less than in 
China in each age group, but is higher overall. The reason for this is that there are more older people 
in Italy, as shown in Figure 3.4(b). In other words, Figure 3.4(a) shows p(F = 1|A,C), where A 
is age, C is country, and F = 1 is the event that someone dies from COVID-19, and Figure 3.4(b) 
shows p(A|C), which is the probability someone is in age bucket A for country C. Combining these, 
we find p(F = 1|C = Italy) > p(F = 1|C = China). See [KGS20] for more details. 
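The way the age distribution reverses the comparison can be seen by marginalizing over age, $p(F = 1|C) = \sum_A p(F = 1|A, C)\, p(A|C)$. The sketch below uses two age buckets and entirely hypothetical numbers, chosen only to reproduce the qualitative pattern; they are not the values reported in [KGS20].

```python
# Hypothetical case fatality rates p(F=1 | A, C); NOT the data from [KGS20].
cfr = {
    "Italy": {"young": 0.001, "old": 0.10},
    "China": {"young": 0.002, "old": 0.12},
}
# Hypothetical age distributions of confirmed cases, p(A | C).
age_dist = {
    "Italy": {"young": 0.4, "old": 0.6},
    "China": {"young": 0.8, "old": 0.2},
}


def overall_cfr(country):
    """Marginalize over age: sum_A p(F=1 | A, C) p(A | C)."""
    return sum(cfr[country][a] * age_dist[country][a] for a in cfr[country])


italy_total = overall_cfr("Italy")
china_total = overall_cfr("China")
lower_in_every_group = all(cfr["Italy"][a] < cfr["China"][a] for a in cfr["Italy"])
```

With these made-up numbers, the rate is lower for Italy within each age bucket, yet higher for Italy after aggregation, because more of its probability mass falls in the high-risk bucket.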


3.2 The multivariate Gaussian (normal) distribution 
The most widely used joint probability distribution for continuous random variables is the
multivariate Gaussian or multivariate normal (MVN). This is mostly because it is mathematically
convenient, but also because the Gaussian assumption is fairly reasonable in many cases (see the
discussion in Section 2.6.4).
Figure 3.4: Illustration of Simpson’s paradox using COVID-19. (a) Case fatality rates (CFRs) in Italy and
China by age group, and in aggregated form (“Total”, last pair of bars), up to the time of reporting (see legend).
(b) Proportion of all confirmed cases included in (a) within each age group by country. From Figure 1 of
[KGS20]. Used with kind permission of Julius von Kügelgen.


3.2.1 Definition 
The MVN density is defined by the following: 


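$$
\mathcal{N}(\mathbf{y}|\boldsymbol{\mu}, \boldsymbol{\Sigma}) \triangleq \frac{1}{(2\pi)^{D/2} |\boldsymbol{\Sigma}|^{1/2}} \exp\left[-\frac{1}{2}(\mathbf{y} - \boldsymbol{\mu})^\top \boldsymbol{\Sigma}^{-1} (\mathbf{y} - \boldsymbol{\mu})\right] \tag{3.11}
$$

where $\boldsymbol{\mu} = \mathbb{E}[\mathbf{y}] \in \mathbb{R}^D$ is the mean vector, and $\boldsymbol{\Sigma} = \mathrm{Cov}[\mathbf{y}]$ is the $D \times D$ covariance matrix,
defined as follows:

$$
\mathrm{Cov}[\mathbf{y}] \triangleq \mathbb{E}\left[(\mathbf{y} - \mathbb{E}[\mathbf{y}])(\mathbf{y} - \mathbb{E}[\mathbf{y}])^\top\right] \tag{3.12}
$$

$$
= \begin{pmatrix}
\mathbb{V}[Y_1] & \mathrm{Cov}[Y_1, Y_2] & \cdots & \mathrm{Cov}[Y_1, Y_D] \\
\mathrm{Cov}[Y_2, Y_1] & \mathbb{V}[Y_2] & \cdots & \mathrm{Cov}[Y_2, Y_D] \\
\vdots & \vdots & \ddots & \vdots \\
\mathrm{Cov}[Y_D, Y_1] & \mathrm{Cov}[Y_D, Y_2] & \cdots & \mathbb{V}[Y_D]
\end{pmatrix} \tag{3.13}
$$

where

$$
\mathrm{Cov}[Y_i, Y_j] \triangleq \mathbb{E}\left[(Y_i - \mathbb{E}[Y_i])(Y_j - \mathbb{E}[Y_j])\right] = \mathbb{E}[Y_i Y_j] - \mathbb{E}[Y_i]\,\mathbb{E}[Y_j] \tag{3.14}
$$

and $\mathbb{V}[Y_i] = \mathrm{Cov}[Y_i, Y_i]$. From Equation (3.12), we get the important result

$$
\mathbb{E}\left[\mathbf{y}\mathbf{y}^\top\right] = \boldsymbol{\Sigma} + \boldsymbol{\mu}\boldsymbol{\mu}^\top \tag{3.15}
$$

The normalization constant in Equation (3.11), $Z = (2\pi)^{D/2}|\boldsymbol{\Sigma}|^{1/2}$, just ensures that the pdf
integrates to 1 (see Exercise 3.6).

In 2d, the MVN is known as the bivariate Gaussian distribution. Its pdf can be represented as
$\mathbf{y} \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$, where $\mathbf{y} \in \mathbb{R}^2$, $\boldsymbol{\mu} \in \mathbb{R}^2$ and

$$
\boldsymbol{\Sigma} = \begin{pmatrix} \sigma_1^2 & \sigma_{12}^2 \\ \sigma_{21}^2 & \sigma_2^2 \end{pmatrix}
= \begin{pmatrix} \sigma_1^2 & \rho\sigma_1\sigma_2 \\ \rho\sigma_1\sigma_2 & \sigma_2^2 \end{pmatrix} \tag{3.16}
$$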
Figure 3.5: Visualization of a 2d Gaussian density as a surface plot. (a) Distribution using a full covariance
matrix can be oriented at any angle. (b) Distribution using a diagonal covariance matrix must be parallel to
the axis. (c) Distribution using a spherical covariance matrix must have a symmetric shape. Generated by
gauss_plot_2d.ipynb.

Figure 3.6: Visualization of a 2d Gaussian density in terms of level sets of constant probability density. (a) A
full covariance matrix has elliptical contours. (b) A diagonal covariance matrix is an axis aligned ellipse.
(c) A spherical covariance matrix has a circular shape. Generated by gauss_plot_2d.ipynb.


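where $\rho$ is the correlation coefficient, defined by

$$
\mathrm{corr}[Y_1, Y_2] \triangleq \frac{\mathrm{Cov}[Y_1, Y_2]}{\sqrt{\mathbb{V}[Y_1]\,\mathbb{V}[Y_2]}} = \frac{\sigma_{12}^2}{\sigma_1 \sigma_2} \tag{3.17}
$$

One can show (Exercise 3.2) that $-1 \le \mathrm{corr}[Y_1, Y_2] \le 1$. Expanding out the pdf in the 2d case gives
the following rather intimidating-looking result:

$$
p(y_1, y_2) = \frac{1}{2\pi\sigma_1\sigma_2\sqrt{1 - \rho^2}} \exp\left(-\frac{1}{2(1 - \rho^2)} \left[\frac{(y_1 - \mu_1)^2}{\sigma_1^2} + \frac{(y_2 - \mu_2)^2}{\sigma_2^2} - 2\rho\frac{(y_1 - \mu_1)(y_2 - \mu_2)}{\sigma_1\sigma_2}\right]\right) \tag{3.18}
$$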

Figure 3.5 and Figure 3.6 plot some MVN densities in 2d for three different kinds of covariance
matrices. A full covariance matrix has $D(D + 1)/2$ parameters, where we divide by 2 since $\boldsymbol{\Sigma}$ is
symmetric. (The reason for the elliptical shape is explained in Section 7.4.4, where we discuss the
geometry of quadratic forms.) A diagonal covariance matrix has $D$ parameters, and has 0s in the
off-diagonal terms. A spherical covariance matrix, also called isotropic covariance matrix,
has the form $\boldsymbol{\Sigma} = \sigma^2 \mathbf{I}_D$, so it only has one free parameter, namely $\sigma^2$.


3.2.2 Mahalanobis distance 


In this section, we attempt to gain some insights into the geometric shape of the Gaussian pdf 
in multiple dimensions. To do this, we will consider the shape of the level sets of constant (log) 
probability. 

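The log probability at a specific point $\mathbf{y}$ is given by

$$
\log p(\mathbf{y}|\boldsymbol{\mu}, \boldsymbol{\Sigma}) = -\frac{1}{2}(\mathbf{y} - \boldsymbol{\mu})^\top \boldsymbol{\Sigma}^{-1} (\mathbf{y} - \boldsymbol{\mu}) + \text{const} \tag{3.19}
$$

The dependence on $\mathbf{y}$ can be expressed in terms of the Mahalanobis distance $\Delta$ between $\mathbf{y}$ and $\boldsymbol{\mu}$,
whose square is defined as follows:

$$
\Delta^2 \triangleq (\mathbf{y} - \boldsymbol{\mu})^\top \boldsymbol{\Sigma}^{-1} (\mathbf{y} - \boldsymbol{\mu}) \tag{3.20}
$$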


Thus contours of constant (log) probability are equivalent to contours of constant Mahalanobis 
distance. 

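To gain insight into the contours of constant Mahalanobis distance, we exploit the fact that $\boldsymbol{\Sigma}$,
and hence $\boldsymbol{\Lambda} = \boldsymbol{\Sigma}^{-1}$, are both positive definite matrices (by assumption). Consider the following
eigendecomposition (Section 7.4) of $\boldsymbol{\Sigma}$:

$$
\boldsymbol{\Sigma} = \sum_{d=1}^{D} \lambda_d \mathbf{u}_d \mathbf{u}_d^\top \tag{3.21}
$$

We can similarly write

$$
\boldsymbol{\Sigma}^{-1} = \sum_{d=1}^{D} \frac{1}{\lambda_d} \mathbf{u}_d \mathbf{u}_d^\top \tag{3.22}
$$

Let us define $z_d \triangleq \mathbf{u}_d^\top (\mathbf{y} - \boldsymbol{\mu})$, so $\mathbf{z} = \mathbf{U}^\top(\mathbf{y} - \boldsymbol{\mu})$, where the columns of $\mathbf{U}$ are the eigenvectors $\mathbf{u}_d$.
Then we can rewrite the Mahalanobis distance as follows:

$$
(\mathbf{y} - \boldsymbol{\mu})^\top \boldsymbol{\Sigma}^{-1} (\mathbf{y} - \boldsymbol{\mu}) = (\mathbf{y} - \boldsymbol{\mu})^\top \left(\sum_{d=1}^{D} \frac{1}{\lambda_d} \mathbf{u}_d \mathbf{u}_d^\top\right) (\mathbf{y} - \boldsymbol{\mu}) \tag{3.23}
$$

$$
= \sum_{d=1}^{D} \frac{1}{\lambda_d} (\mathbf{y} - \boldsymbol{\mu})^\top \mathbf{u}_d \mathbf{u}_d^\top (\mathbf{y} - \boldsymbol{\mu}) = \sum_{d=1}^{D} \frac{z_d^2}{\lambda_d} \tag{3.24}
$$

As we discuss in Section 7.4.4, this means we can interpret the Mahalanobis distance as Euclidean
distance in a new coordinate frame $\mathbf{z}$, in which we rotate $\mathbf{y} - \boldsymbol{\mu}$ by $\mathbf{U}$ and rescale each axis $d$ by $1/\sqrt{\lambda_d}$.
For example, in 2d, let us consider the set of points $(z_1, z_2)$ that satisfy this equation:

$$
\frac{z_1^2}{\lambda_1} + \frac{z_2^2}{\lambda_2} = r \tag{3.25}
$$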

Since these points have the same Mahalanobis distance, they correspond to points of equal probability. 
Hence we see that the contours of equal probability density of a 2d Gaussian lie along ellipses. This is 
illustrated in Figure 7.6. The eigenvectors determine the orientation of the ellipse, and the eigenvalues 
determine how elongated it is. 
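To make the change of coordinates concrete, the following sketch computes the squared Mahalanobis distance in two ways: directly from Equation (3.20), and via the eigendecomposition as in Equation (3.24). It assumes NumPy; the mean, covariance and query point are made up for illustration.

```python
import numpy as np

# Hypothetical 2d example.
mu = np.array([1.0, -1.0])
Sigma = np.array([[2.0, 0.8],
                  [0.8, 1.0]])
y = np.array([2.5, 0.0])

# Direct evaluation of Equation (3.20), solving a linear system
# instead of forming the explicit inverse.
diff = y - mu
maha_sq_direct = diff @ np.linalg.solve(Sigma, diff)

# Eigendecomposition of the symmetric matrix Sigma; the columns of U
# are the eigenvectors u_d and lam holds the eigenvalues lambda_d.
lam, U = np.linalg.eigh(Sigma)
z = U.T @ diff                        # coordinates in the rotated frame
maha_sq_eig = np.sum(z ** 2 / lam)    # Equation (3.24)
```

We use `np.linalg.eigh` rather than `np.linalg.eig` because the covariance matrix is symmetric, which guarantees real eigenvalues and orthonormal eigenvectors.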


3.2.3 Marginals and conditionals of an MVN * 


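Suppose $\mathbf{y} = (\mathbf{y}_1, \mathbf{y}_2)$ is jointly Gaussian with parameters

$$
\boldsymbol{\mu} = \begin{pmatrix} \boldsymbol{\mu}_1 \\ \boldsymbol{\mu}_2 \end{pmatrix}, \quad
\boldsymbol{\Sigma} = \begin{pmatrix} \boldsymbol{\Sigma}_{11} & \boldsymbol{\Sigma}_{12} \\ \boldsymbol{\Sigma}_{21} & \boldsymbol{\Sigma}_{22} \end{pmatrix}, \quad
\boldsymbol{\Lambda} = \boldsymbol{\Sigma}^{-1} = \begin{pmatrix} \boldsymbol{\Lambda}_{11} & \boldsymbol{\Lambda}_{12} \\ \boldsymbol{\Lambda}_{21} & \boldsymbol{\Lambda}_{22} \end{pmatrix} \tag{3.26}
$$

where $\boldsymbol{\Lambda}$ is the precision matrix. Then the marginals are given by

$$
\begin{aligned}
p(\mathbf{y}_1) &= \mathcal{N}(\mathbf{y}_1|\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_{11}) \\
p(\mathbf{y}_2) &= \mathcal{N}(\mathbf{y}_2|\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_{22})
\end{aligned} \tag{3.27}
$$

and the posterior conditional is given by

$$
\begin{aligned}
p(\mathbf{y}_1|\mathbf{y}_2) &= \mathcal{N}(\mathbf{y}_1|\boldsymbol{\mu}_{1|2}, \boldsymbol{\Sigma}_{1|2}) \\
\boldsymbol{\mu}_{1|2} &= \boldsymbol{\mu}_1 + \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}(\mathbf{y}_2 - \boldsymbol{\mu}_2) \\
&= \boldsymbol{\mu}_1 - \boldsymbol{\Lambda}_{11}^{-1}\boldsymbol{\Lambda}_{12}(\mathbf{y}_2 - \boldsymbol{\mu}_2) \\
&= \boldsymbol{\Sigma}_{1|2}\left(\boldsymbol{\Lambda}_{11}\boldsymbol{\mu}_1 - \boldsymbol{\Lambda}_{12}(\mathbf{y}_2 - \boldsymbol{\mu}_2)\right) \\
\boldsymbol{\Sigma}_{1|2} &= \boldsymbol{\Sigma}_{11} - \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}\boldsymbol{\Sigma}_{21} = \boldsymbol{\Lambda}_{11}^{-1}
\end{aligned} \tag{3.28}
$$

These equations are of such crucial importance in this book that we have put a box around them,
so you can easily find them later. For the derivation of these results (which relies on computing the
Schur complement $\boldsymbol{\Sigma}/\boldsymbol{\Sigma}_{22} = \boldsymbol{\Sigma}_{11} - \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}\boldsymbol{\Sigma}_{21}$), see Section 7.3.5.

We see that both the marginal and conditional distributions are themselves Gaussian. For the
marginals, we just extract the rows and columns corresponding to $\mathbf{y}_1$ or $\mathbf{y}_2$. For the conditional, we
have to do a bit more work. However, it is not that complicated: the conditional mean is just a
linear function of $\mathbf{y}_2$, and the conditional covariance is just a constant matrix that is independent of
$\mathbf{y}_2$. We give three different (but equivalent) expressions for the posterior mean, and two different
(but equivalent) expressions for the posterior covariance; each one is useful in different circumstances.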
3.2.4 Example: conditioning a 2d Gaussian


Let us consider a 2d example. The covariance matrix is

$$
\boldsymbol{\Sigma} = \begin{pmatrix} \sigma_1^2 & \rho\sigma_1\sigma_2 \\ \rho\sigma_1\sigma_2 & \sigma_2^2 \end{pmatrix} \tag{3.29}
$$

The marginal $p(y_1)$ is a 1D Gaussian, obtained by projecting the joint distribution onto the $y_1$ line:

$$
p(y_1) = \mathcal{N}(y_1|\mu_1, \sigma_1^2) \tag{3.30}
$$

Suppose we observe $Y_2 = y_2$; the conditional $p(y_1|y_2)$ is obtained by “slicing” the joint distribution
through the $Y_2 = y_2$ line:

$$
p(y_1|y_2) = \mathcal{N}\left(y_1 \,\Big|\, \mu_1 + \frac{\rho\sigma_1\sigma_2}{\sigma_2^2}(y_2 - \mu_2),\; \sigma_1^2 - \frac{(\rho\sigma_1\sigma_2)^2}{\sigma_2^2}\right) \tag{3.31}
$$

If $\sigma_1 = \sigma_2 = \sigma$, we get

$$
p(y_1|y_2) = \mathcal{N}\left(y_1|\mu_1 + \rho(y_2 - \mu_2),\; \sigma^2(1 - \rho^2)\right) \tag{3.32}
$$

For example, suppose $\rho = 0.8$, $\sigma_1 = \sigma_2 = 1$, $\mu_1 = \mu_2 = 0$, and $y_2 = 1$. We see that $\mathbb{E}[y_1|y_2 = 1] = 0.8$,
which makes sense, since $\rho = 0.8$ means that we believe that if $y_2$ increases by 1 (beyond its
mean), then $y_1$ increases by 0.8. We also see $\mathbb{V}[y_1|y_2 = 1] = 1 - 0.8^2 = 0.36$. This also makes sense:
our uncertainty about $y_1$ has gone down, since we have learned something about $y_1$ (indirectly) by
observing $y_2$. If $\rho = 0$, we get $p(y_1|y_2) = \mathcal{N}(y_1|\mu_1, \sigma_1^2)$, since $y_2$ conveys no information about $y_1$ if
they are uncorrelated (and hence independent).
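The numerical example above can be reproduced with a few lines of standard-library Python. This is a minimal sketch of Equation (3.31); the function name `condition_bivariate` is our own choice for illustration.

```python
def condition_bivariate(mu1, mu2, sigma1, sigma2, rho, y2):
    """Mean and variance of p(y1 | y2) for a bivariate Gaussian, Equation (3.31)."""
    cov12 = rho * sigma1 * sigma2
    mean = mu1 + cov12 / sigma2 ** 2 * (y2 - mu2)
    var = sigma1 ** 2 - cov12 ** 2 / sigma2 ** 2
    return mean, var


# Values from the worked example: rho = 0.8, unit variances, zero means, y2 = 1.
cond_mean, cond_var = condition_bivariate(0.0, 0.0, 1.0, 1.0, 0.8, 1.0)

# With rho = 0 the observation y2 carries no information about y1.
indep_mean, indep_var = condition_bivariate(0.0, 0.0, 1.0, 1.0, 0.0, 1.0)
```

Note that the conditional variance does not depend on the observed value $y_2$, only on $\rho$ and the marginal variances.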


3.2.5 Example: Imputing missing values * 


As an example application of the above results, suppose we observe some parts (dimensions) of y, 
with the remaining parts being missing or unobserved. We can exploit the correlation amongst the 
dimensions (encoded by the covariance matrix) to infer the missing entries; this is called missing 
value imputation. 

Figure 3.7 shows a simple example. We sampled $N = 10$ vectors from a $D = 8$-dimensional
Gaussian, and then deliberately “hid” 50% of the data in each sample (row). We then inferred the
missing entries given the observed entries and the true model parameters.¹ More precisely, for each
row $n$ of the data matrix, we compute $p(\mathbf{y}_{n,h}|\mathbf{y}_{n,v}, \boldsymbol{\theta})$, where $v$ are the indices of the visible entries in
that row, $h$ are the remaining indices of the hidden entries, and $\boldsymbol{\theta} = (\boldsymbol{\mu}, \boldsymbol{\Sigma})$. From this, we compute
the marginal distribution of each missing variable $i \in h$, $p(y_{n,i}|\mathbf{y}_{n,v}, \boldsymbol{\theta})$. From the marginal, we
compute the posterior mean, $\bar{y}_{n,i} = \mathbb{E}[y_{n,i}|\mathbf{y}_{n,v}, \boldsymbol{\theta}]$.

The posterior mean represents our “best guess” about the true value of that entry, in the sense
that it minimizes our expected squared error, as explained in Chapter 5. We can use $\mathbb{V}[y_{n,i}|\mathbf{y}_{n,v}, \boldsymbol{\theta}]$
as a measure of confidence in this guess, although this is not shown. Alternatively, we could draw
multiple posterior samples from $p(\mathbf{y}_{n,h}|\mathbf{y}_{n,v}, \boldsymbol{\theta})$; this is called multiple imputation, and provides a
more robust estimate to downstream algorithms that consume the “filled in” data.
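The per-row computation can be sketched as follows, by applying the conditioning formula of Equation (3.28) to the hidden and visible index sets. This is an illustrative sketch assuming NumPy, not the code used to produce Figure 3.7; the function name `impute_row` and the 3-dimensional example are hypothetical.

```python
import numpy as np


def impute_row(y, visible, mu, Sigma):
    """Posterior mean and covariance of the hidden entries of y given the visible ones.

    `visible` is a boolean mask; entries of y where it is False are ignored.
    """
    v = np.asarray(visible, dtype=bool)
    h = ~v
    S_hv = Sigma[np.ix_(h, v)]
    S_vv = Sigma[np.ix_(v, v)]
    # gain = Sigma_hv Sigma_vv^{-1}, computed via a linear solve (Sigma_vv is symmetric).
    gain = np.linalg.solve(S_vv, S_hv.T).T
    post_mean = mu[h] + gain @ (y[v] - mu[v])
    post_cov = Sigma[np.ix_(h, h)] - gain @ S_hv.T
    return post_mean, post_cov


# Made-up parameters: dimensions 0 and 1 are correlated, dimension 2 is independent.
mu = np.zeros(3)
Sigma = np.array([[1.0, 0.8, 0.0],
                  [0.8, 1.0, 0.0],
                  [0.0, 0.0, 2.0]])
y = np.array([np.nan, 1.0, 5.0])     # first entry is missing
post_mean, post_cov = impute_row(y, [False, True, True], mu, Sigma)
```

The diagonal of `post_cov` gives the posterior variance of each imputed entry, which can serve as the confidence measure mentioned above.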


1. In practice, we would need to estimate the parameters from the partially observed data. Unfortunately the MLE 
results in Section 4.2.6 no longer apply, but we can use the EM algorithm to derive an approximate MLE in the 
presence of missing data. See the sequel to this book for details. 
Figure 3.7: Illustration of data imputation using an MVN. (a) Visualization of the data matrix of size
N = 10, D = 8. Blank entries are missing (not observed). Blue are positive, green are negative. Area
of the square is proportional to the value. (This is known as a Hinton diagram, named after Geoff
Hinton, a famous ML researcher.) (b) True data matrix (hidden). (c) Mean of the posterior predictive
distribution, based on partially observed data in that row, using the true model parameters. Generated by
gauss_imputation_known_params_demo.ipynb.


3.3 Linear Gaussian systems * 


In Section 3.2.3, we conditioned on noise-free observations to infer the posterior over the hidden parts 
of a Gaussian random vector. In this section, we extend this approach to handle noisy observations. 

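Let $\mathbf{z} \in \mathbb{R}^L$ be an unknown vector of values, and $\mathbf{y} \in \mathbb{R}^D$ be some noisy measurement of $\mathbf{z}$. We
assume these variables are related by the following joint distribution:

$$
p(\mathbf{z}) = \mathcal{N}(\mathbf{z}|\boldsymbol{\mu}_z, \boldsymbol{\Sigma}_z) \tag{3.33}
$$

$$
p(\mathbf{y}|\mathbf{z}) = \mathcal{N}(\mathbf{y}|\mathbf{W}\mathbf{z} + \mathbf{b}, \boldsymbol{\Sigma}_y) \tag{3.34}
$$

where $\mathbf{W}$ is a matrix of size $D \times L$. This is an example of a linear Gaussian system.
The corresponding joint distribution, $p(\mathbf{z}, \mathbf{y}) = p(\mathbf{z})p(\mathbf{y}|\mathbf{z})$, is an $(L + D)$-dimensional Gaussian, with
mean and covariance given by

$$
\boldsymbol{\mu} = \begin{pmatrix} \boldsymbol{\mu}_z \\ \mathbf{W}\boldsymbol{\mu}_z + \mathbf{b} \end{pmatrix} \tag{3.35}
$$

$$
\boldsymbol{\Sigma} = \begin{pmatrix} \boldsymbol{\Sigma}_z & \boldsymbol{\Sigma}_z\mathbf{W}^\top \\ \mathbf{W}\boldsymbol{\Sigma}_z & \boldsymbol{\Sigma}_y + \mathbf{W}\boldsymbol{\Sigma}_z\mathbf{W}^\top \end{pmatrix} \tag{3.36}
$$

By applying the Gaussian conditioning formula in Equation (3.28) to the joint $p(\mathbf{y}, \mathbf{z})$ we can
compute the posterior $p(\mathbf{z}|\mathbf{y})$, as we explain below. This can be interpreted as inverting the $\mathbf{z} \rightarrow \mathbf{y}$
arrow in the generative model from latents to observations.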
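3.3.1 Bayes rule for Gaussians


The posterior over the latent is given by

$$
\begin{aligned}
p(\mathbf{z}|\mathbf{y}) &= \mathcal{N}(\mathbf{z}|\boldsymbol{\mu}_{z|y}, \boldsymbol{\Sigma}_{z|y}) \\
\boldsymbol{\Sigma}_{z|y}^{-1} &= \boldsymbol{\Sigma}_z^{-1} + \mathbf{W}^\top\boldsymbol{\Sigma}_y^{-1}\mathbf{W} \\
\boldsymbol{\mu}_{z|y} &= \boldsymbol{\Sigma}_{z|y}\left[\mathbf{W}^\top\boldsymbol{\Sigma}_y^{-1}(\mathbf{y} - \mathbf{b}) + \boldsymbol{\Sigma}_z^{-1}\boldsymbol{\mu}_z\right]
\end{aligned} \tag{3.37}
$$

This is known as Bayes rule for Gaussians. Furthermore, the normalization constant of the
posterior is given by

$$
p(\mathbf{y}) = \int \mathcal{N}(\mathbf{z}|\boldsymbol{\mu}_z, \boldsymbol{\Sigma}_z)\,\mathcal{N}(\mathbf{y}|\mathbf{W}\mathbf{z} + \mathbf{b}, \boldsymbol{\Sigma}_y)\,d\mathbf{z} = \mathcal{N}(\mathbf{y}|\mathbf{W}\boldsymbol{\mu}_z + \mathbf{b}, \boldsymbol{\Sigma}_y + \mathbf{W}\boldsymbol{\Sigma}_z\mathbf{W}^\top) \tag{3.38}
$$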


We see that the Gaussian prior p(z), combined with the Gaussian likelihood p(y|z), results in a 
Gaussian posterior p(z|y). Thus Gaussians are closed under Bayesian conditioning. To describe this 
more generally, we say that the Gaussian prior is a conjugate prior for the Gaussian likelihood, 
since the posterior distribution has the same type as the prior. We discuss the notion of conjugate 
priors in more detail in Section 4.6.1. 
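The posterior formulas of Equation (3.37) translate directly into code. The following is a minimal sketch assuming NumPy; the function name `gauss_posterior` and all numerical values are hypothetical. As a sanity check, it compares the result against conditioning the joint Gaussian of Equations (3.35) and (3.36) using Equation (3.28).

```python
import numpy as np


def gauss_posterior(mu_z, Sigma_z, W, b, Sigma_y, y):
    """Posterior p(z|y) of a linear Gaussian system, following Equation (3.37)."""
    prec_z = np.linalg.inv(Sigma_z)
    prec_y = np.linalg.inv(Sigma_y)
    post_cov = np.linalg.inv(prec_z + W.T @ prec_y @ W)
    post_mean = post_cov @ (W.T @ prec_y @ (y - b) + prec_z @ mu_z)
    return post_mean, post_cov


# Made-up system with L = 2 latent and D = 3 observed dimensions.
mu_z = np.array([0.0, 1.0])
Sigma_z = np.array([[1.0, 0.3], [0.3, 2.0]])
W = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
b = np.array([0.5, 0.0, -0.5])
Sigma_y = 0.5 * np.eye(3)
y = np.array([1.0, 2.0, 3.0])

post_mean, post_cov = gauss_posterior(mu_z, Sigma_z, W, b, Sigma_y, y)

# Cross-check: condition the joint Gaussian on y using Equation (3.28).
S_yy = Sigma_y + W @ Sigma_z @ W.T
alt_mean = mu_z + Sigma_z @ W.T @ np.linalg.solve(S_yy, y - (W @ mu_z + b))
alt_cov = Sigma_z - Sigma_z @ W.T @ np.linalg.solve(S_yy, W @ Sigma_z)
```

The two routes invert matrices of different sizes ($L \times L$ versus $D \times D$), so in practice one might prefer whichever is smaller.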

In the sections below, we give various applications of this result. But first, we give the derivation. 


3.3.2 Derivation * 


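We now derive Equation 3.37. The basic idea is to derive the joint distribution, $p(\mathbf{z}, \mathbf{y}) = p(\mathbf{z})p(\mathbf{y}|\mathbf{z})$,
and then to use the results from Section 3.2.3 for computing $p(\mathbf{z}|\mathbf{y})$.

In more detail, we proceed as follows. The log of the joint distribution is as follows (dropping
irrelevant constants):

$$
\log p(\mathbf{z}, \mathbf{y}) = -\frac{1}{2}(\mathbf{z} - \boldsymbol{\mu}_z)^\top\boldsymbol{\Sigma}_z^{-1}(\mathbf{z} - \boldsymbol{\mu}_z) - \frac{1}{2}(\mathbf{y} - \mathbf{W}\mathbf{z} - \mathbf{b})^\top\boldsymbol{\Sigma}_y^{-1}(\mathbf{y} - \mathbf{W}\mathbf{z} - \mathbf{b}) \tag{3.39}
$$

This is clearly a joint Gaussian distribution, since it is the exponential of a quadratic form.
Expanding out the quadratic terms involving $\mathbf{z}$ and $\mathbf{y}$, and ignoring linear and constant terms, we
have

$$
Q = -\frac{1}{2}\mathbf{z}^\top\boldsymbol{\Sigma}_z^{-1}\mathbf{z} - \frac{1}{2}\mathbf{y}^\top\boldsymbol{\Sigma}_y^{-1}\mathbf{y} - \frac{1}{2}(\mathbf{W}\mathbf{z})^\top\boldsymbol{\Sigma}_y^{-1}(\mathbf{W}\mathbf{z}) + \mathbf{y}^\top\boldsymbol{\Sigma}_y^{-1}\mathbf{W}\mathbf{z} \tag{3.40}
$$

$$
= -\frac{1}{2}\begin{pmatrix}\mathbf{z} \\ \mathbf{y}\end{pmatrix}^\top
\begin{pmatrix}\boldsymbol{\Sigma}_z^{-1} + \mathbf{W}^\top\boldsymbol{\Sigma}_y^{-1}\mathbf{W} & -\mathbf{W}^\top\boldsymbol{\Sigma}_y^{-1} \\ -\boldsymbol{\Sigma}_y^{-1}\mathbf{W} & \boldsymbol{\Sigma}_y^{-1}\end{pmatrix}
\begin{pmatrix}\mathbf{z} \\ \mathbf{y}\end{pmatrix} \tag{3.41}
$$

$$
= -\frac{1}{2}\begin{pmatrix}\mathbf{z} \\ \mathbf{y}\end{pmatrix}^\top \boldsymbol{\Sigma}^{-1} \begin{pmatrix}\mathbf{z} \\ \mathbf{y}\end{pmatrix} \tag{3.42}
$$

where the precision matrix of the joint is defined as

$$
\boldsymbol{\Sigma}^{-1} = \begin{pmatrix}\boldsymbol{\Sigma}_z^{-1} + \mathbf{W}^\top\boldsymbol{\Sigma}_y^{-1}\mathbf{W} & -\mathbf{W}^\top\boldsymbol{\Sigma}_y^{-1} \\ -\boldsymbol{\Sigma}_y^{-1}\mathbf{W} & \boldsymbol{\Sigma}_y^{-1}\end{pmatrix}
\triangleq \boldsymbol{\Lambda} = \begin{pmatrix}\boldsymbol{\Lambda}_{zz} & \boldsymbol{\Lambda}_{zy} \\ \boldsymbol{\Lambda}_{yz} & \boldsymbol{\Lambda}_{yy}\end{pmatrix} \tag{3.43}
$$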
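From Equation 3.28, and using the fact that $\boldsymbol{\mu}_y = \mathbf{W}\boldsymbol{\mu}_z + \mathbf{b}$, we have

$$
\begin{aligned}
p(\mathbf{z}|\mathbf{y}) &= \mathcal{N}(\boldsymbol{\mu}_{z|y}, \boldsymbol{\Sigma}_{z|y}) \\
\boldsymbol{\Sigma}_{z|y} &= \boldsymbol{\Lambda}_{zz}^{-1} = (\boldsymbol{\Sigma}_z^{-1} + \mathbf{W}^\top\boldsymbol{\Sigma}_y^{-1}\mathbf{W})^{-1} \\
\boldsymbol{\mu}_{z|y} &= \boldsymbol{\Sigma}_{z|y}\left(\boldsymbol{\Lambda}_{zz}\boldsymbol{\mu}_z - \boldsymbol{\Lambda}_{zy}(\mathbf{y} - \boldsymbol{\mu}_y)\right) \\
&= \boldsymbol{\Sigma}_{z|y}\left(\boldsymbol{\Sigma}_z^{-1}\boldsymbol{\mu}_z + \mathbf{W}^\top\boldsymbol{\Sigma}_y^{-1}\mathbf{W}\boldsymbol{\mu}_z + \mathbf{W}^\top\boldsymbol{\Sigma}_y^{-1}(\mathbf{y} - \boldsymbol{\mu}_y)\right) \\
&= \boldsymbol{\Sigma}_{z|y}\left(\boldsymbol{\Sigma}_z^{-1}\boldsymbol{\mu}_z + \mathbf{W}^\top\boldsymbol{\Sigma}_y^{-1}(\mathbf{W}\boldsymbol{\mu}_z + \mathbf{y} - \boldsymbol{\mu}_y)\right) \\
&= \boldsymbol{\Sigma}_{z|y}\left(\boldsymbol{\Sigma}_z^{-1}\boldsymbol{\mu}_z + \mathbf{W}^\top\boldsymbol{\Sigma}_y^{-1}(\mathbf{y} - \mathbf{b})\right)
\end{aligned}
$$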

3.3.2.1 Completing the square 


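When working with linear Gaussian systems, it is common to use an algebraic trick called completing
the square. In the scalar case, this says that we can write a quadratic function of the form

$$
f(z) = az^2 + bz + c
$$

as follows:

$$
\begin{aligned}
az^2 + bz + c &= a(z - h)^2 + k \\
h &= -\frac{b}{2a} \\
k &= c - \frac{b^2}{4a}
\end{aligned}
$$

In the vector case, this says we write a quadratic function of the form

$$
f(\mathbf{x}) = \mathbf{x}^\top\mathbf{A}\mathbf{x} + \mathbf{x}^\top\mathbf{b} + c
$$

as follows (assuming $\mathbf{A}$ is symmetric and invertible):

$$
\begin{aligned}
\mathbf{x}^\top\mathbf{A}\mathbf{x} + \mathbf{x}^\top\mathbf{b} + c &= (\mathbf{x} - \mathbf{h})^\top\mathbf{A}(\mathbf{x} - \mathbf{h}) + k \\
\mathbf{h} &= -\frac{1}{2}\mathbf{A}^{-1}\mathbf{b} \\
k &= c - \frac{1}{4}\mathbf{b}^\top\mathbf{A}^{-1}\mathbf{b}
\end{aligned}
$$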

This trick will be used in more advanced derivations. 
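As a quick check of the scalar case, we can take $h = -b/(2a)$ and $k = c - b^2/(4a)$ (the one-dimensional specialization of the vector formulas above) and compare both sides of the identity at a few points. The coefficients below are arbitrary illustrative values.

```python
def complete_square(a, b, c):
    """Return (h, k) such that a z^2 + b z + c == a (z - h)^2 + k, for a != 0."""
    h = -b / (2.0 * a)
    k = c - b * b / (4.0 * a)
    return h, k


a, b, c = 2.0, -3.0, 0.5          # arbitrary coefficients
h, k = complete_square(a, b, c)

# Compare the two forms of the quadratic at several test points.
gaps = [
    abs((a * z * z + b * z + c) - (a * (z - h) ** 2 + k))
    for z in (-2.0, 0.0, 0.7, 3.0)
]
```

In the Gaussian setting, $h$ plays the role of the mean and $a$ the role of (minus one half of) the precision, which is why this trick lets us read off posterior parameters from a log density.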


3.3.3 Example: Inferring an unknown scalar 


Suppose we make $N$ noisy measurements $y_i$ of some underlying quantity $z$; let us assume the
measurement noise has fixed precision $\lambda_y = 1/\sigma^2$, so the likelihood is

$$
p(y_i|z) = \mathcal{N}(y_i|z, \lambda_y^{-1})
$$
Figure 3.8: Inference about z given a noisy observation y = 3. (a) Strong prior N(0, 1). The posterior mean
is “shrunk” towards the prior mean, which is 0. (b) Weak prior N(0, 5). The posterior mean is similar to the
MLE. Generated by gauss_infer_1d.ipynb.


Now let us use a Gaussian prior for the value of the unknown source:

$$
p(z) = \mathcal{N}(z|\mu_0, \lambda_0^{-1}) \tag{3.51}
$$

We want to compute $p(z|y_1, \ldots, y_N, \sigma^2)$. We can convert this to a form that lets us apply Bayes
rule for Gaussians by defining $\mathbf{y} = (y_1, \ldots, y_N)$, $\mathbf{W} = \mathbf{1}_N$ (an $N \times 1$ column vector of 1’s), and
$\boldsymbol{\Sigma}_y^{-1} = \mathrm{diag}(\lambda_y \mathbf{I})$. Then we get

$$
p(z|\mathbf{y}) = \mathcal{N}(z|\mu_N, \lambda_N^{-1}) \tag{3.52}
$$

$$
\lambda_N = \lambda_0 + N\lambda_y \tag{3.53}
$$

$$
\mu_N = \frac{N\lambda_y\bar{y} + \lambda_0\mu_0}{\lambda_N} = \frac{N\lambda_y}{N\lambda_y + \lambda_0}\bar{y} + \frac{\lambda_0}{N\lambda_y + \lambda_0}\mu_0 \tag{3.54}
$$

These equations are quite intuitive: the posterior precision $\lambda_N$ is the prior precision $\lambda_0$ plus $N$ units
of measurement precision $\lambda_y$. Also, the posterior mean $\mu_N$ is a convex combination of the MLE $\bar{y}$
and the prior mean $\mu_0$. This makes it clear that the posterior mean is a compromise between the
MLE and the prior. If the prior is weak relative to the signal strength ($\lambda_0$ is small relative to $\lambda_y$), we
put more weight on the MLE. If the prior is strong relative to the signal strength ($\lambda_0$ is large relative
to $\lambda_y$), we put more weight on the prior. This is illustrated in Figure 3.8.

Note that the posterior mean is written in terms of $N\lambda_y\bar{y}$, so having $N$ measurements each of
precision $\lambda_y$ is like having one measurement with value $\bar{y}$ and precision $N\lambda_y$.
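The precision-form update can be sketched in a few lines of standard-library Python. The observation $y = 3$ and the two priors follow the caption of Figure 3.8; the noise variance of 1 is an assumption made here for illustration, and the function name `posterior_scalar` is our own.

```python
def posterior_scalar(ys, noise_var, prior_mean, prior_var):
    """Posterior mean and variance of z, following Equations (3.52) to (3.54)."""
    n = len(ys)
    lam_y = 1.0 / noise_var        # measurement precision
    lam_0 = 1.0 / prior_var        # prior precision
    lam_n = lam_0 + n * lam_y      # posterior precision, Equation (3.53)
    ybar = sum(ys) / n
    mu_n = (n * lam_y * ybar + lam_0 * prior_mean) / lam_n   # Equation (3.54)
    return mu_n, 1.0 / lam_n


ys = [3.0]
# Assumed noise variance of 1 (not stated in the text).
mu_strong, var_strong = posterior_scalar(ys, 1.0, prior_mean=0.0, prior_var=1.0)
mu_weak, var_weak = posterior_scalar(ys, 1.0, prior_mean=0.0, prior_var=5.0)
```

Under this assumed noise level, the strong prior pulls the posterior mean halfway back towards 0, whereas the weak prior leaves it much closer to the observation.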

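We can rewrite the results in terms of the posterior variance, rather than posterior precision, as
follows:

$$
p(z|\mathcal{D}, \sigma^2) = \mathcal{N}(z|\mu_N, \tau_N^2) \tag{3.55}
$$

$$
\tau_N^2 = \frac{1}{\frac{N}{\sigma^2} + \frac{1}{\tau_0^2}} = \frac{\sigma^2\tau_0^2}{N\tau_0^2 + \sigma^2} \tag{3.56}
$$

$$
\mu_N = \tau_N^2\left(\frac{\mu_0}{\tau_0^2} + \frac{N\bar{y}}{\sigma^2}\right) = \frac{\sigma^2}{N\tau_0^2 + \sigma^2}\mu_0 + \frac{N\tau_0^2}{N\tau_0^2 + \sigma^2}\bar{y} \tag{3.57}
$$

where $\tau_0^2 = 1/\lambda_0$ is the prior variance and $\tau_N^2 = 1/\lambda_N$ is the posterior variance.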

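We can also compute the posterior sequentially, by updating after each observation. If $N = 1$,
we can rewrite the posterior after seeing a single observation as follows (where we define $\Sigma_y = \sigma^2$,
$\Sigma_0 = \tau_0^2$ and $\Sigma_1 = \tau_1^2$ to be the variances of the likelihood, prior and posterior):

$$
p(z|y) = \mathcal{N}(z|\mu_1, \Sigma_1) \tag{3.58}
$$

$$
\Sigma_1 = \left(\frac{1}{\Sigma_0} + \frac{1}{\Sigma_y}\right)^{-1} = \frac{\Sigma_y\Sigma_0}{\Sigma_0 + \Sigma_y} \tag{3.59}
$$

$$
\mu_1 = \Sigma_1\left(\frac{\mu_0}{\Sigma_0} + \frac{y}{\Sigma_y}\right) \tag{3.60}
$$

We can rewrite the posterior mean in 3 different ways:

$$
\mu_1 = \frac{\Sigma_y}{\Sigma_y + \Sigma_0}\mu_0 + \frac{\Sigma_0}{\Sigma_y + \Sigma_0}y \tag{3.61}
$$

$$
= \mu_0 + (y - \mu_0)\frac{\Sigma_0}{\Sigma_y + \Sigma_0} \tag{3.62}
$$

$$
= y - (y - \mu_0)\frac{\Sigma_y}{\Sigma_y + \Sigma_0} \tag{3.63}
$$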

The first equation is a convex combination of the prior and the data. The second equation is the
prior mean adjusted towards the data. The third equation is the data adjusted towards the prior
mean; this is called shrinkage. These are all equivalent ways of expressing the tradeoff between
likelihood and prior. If $\Sigma_0$ is small relative to $\Sigma_y$, corresponding to a strong prior, the amount of
shrinkage is large (see Figure 3.8(a)), whereas if $\Sigma_0$ is large relative to $\Sigma_y$, corresponding to a weak
prior, the amount of shrinkage is small (see Figure 3.8(b)).

Another way to quantify the amount of shrinkage is in terms of the signal-to-noise ratio, which
is defined as follows:

$$
\mathrm{SNR} \triangleq \frac{\mathbb{E}[Z^2]}{\mathbb{E}[\epsilon^2]} = \frac{\Sigma_0 + \mu_0^2}{\Sigma_y} \tag{3.64}
$$

where $z \sim \mathcal{N}(\mu_0, \Sigma_0)$ is the true signal, $y = z + \epsilon$ is the observed signal, and $\epsilon \sim \mathcal{N}(0, \Sigma_y)$ is the
noise term.


3.3.4 Example: inferring an unknown vector 


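Suppose we have an unknown quantity of interest, $\mathbf{z} \in \mathbb{R}^D$, which we endow with a Gaussian prior,
$p(\mathbf{z}) = \mathcal{N}(\boldsymbol{\mu}_z, \boldsymbol{\Sigma}_z)$. If we “know nothing” about $\mathbf{z}$ a priori, we can set $\boldsymbol{\Sigma}_z = \infty\mathbf{I}$, which means we are
completely uncertain about what the value of $\mathbf{z}$ should be. (In practice, we can use a large but finite
value for the covariance.) By symmetry, it seems reasonable to set $\boldsymbol{\mu}_z = \mathbf{0}$.

Now suppose we make $N$ noisy but independent measurements of $\mathbf{z}$, $\mathbf{y}_n \sim \mathcal{N}(\mathbf{z}, \boldsymbol{\Sigma}_y)$, each of size
$D$. One can show that the likelihood of $N$ observations can be represented by a single Gaussian
evaluated at their average, $\bar{\mathbf{y}}$, provided we scale down the covariance by $1/N$ to compensate for the
increased measurement precision, i.e.,

$$
p(\mathcal{D}|\mathbf{z}) = \prod_{n=1}^{N}\mathcal{N}(\mathbf{y}_n|\mathbf{z}, \boldsymbol{\Sigma}_y) \propto \mathcal{N}\left(\bar{\mathbf{y}}|\mathbf{z}, \frac{1}{N}\boldsymbol{\Sigma}_y\right) \tag{3.65}
$$

To see why this is true, consider the case of two measurements. The log likelihood can then be
written using canonical parameters as follows:²

$$
\begin{aligned}
\log(p(\mathbf{y}_1|\mathbf{z})p(\mathbf{y}_2|\mathbf{z})) &= K_1 - \frac{1}{2}\left(\mathbf{z}^\top\boldsymbol{\Sigma}_y^{-1}\mathbf{z} - 2\mathbf{z}^\top\boldsymbol{\Sigma}_y^{-1}\mathbf{y}_1\right) - \frac{1}{2}\left(\mathbf{z}^\top\boldsymbol{\Sigma}_y^{-1}\mathbf{z} - 2\mathbf{z}^\top\boldsymbol{\Sigma}_y^{-1}\mathbf{y}_2\right) \\
&= K_1 - \frac{1}{2}\left(2\mathbf{z}^\top\boldsymbol{\Sigma}_y^{-1}\mathbf{z} - 2\mathbf{z}^\top\boldsymbol{\Sigma}_y^{-1}(\mathbf{y}_1 + \mathbf{y}_2)\right) \\
&= K_1 - \frac{1}{2}\left(\mathbf{z}^\top 2\boldsymbol{\Sigma}_y^{-1}\mathbf{z} - 2\mathbf{z}^\top 2\boldsymbol{\Sigma}_y^{-1}\bar{\mathbf{y}}\right) \\
&= K_2 + \log\mathcal{N}\left(\mathbf{z}|\bar{\mathbf{y}}, \frac{\boldsymbol{\Sigma}_y}{2}\right) = K_2 + \log\mathcal{N}\left(\bar{\mathbf{y}}|\mathbf{z}, \frac{\boldsymbol{\Sigma}_y}{2}\right)
\end{aligned}
$$

where $K_1$ and $K_2$ are constants independent of $\mathbf{z}$.
Using this, and setting $\mathbf{W} = \mathbf{I}$, $\mathbf{b} = \mathbf{0}$, we can then use Bayes rule for Gaussians to compute the
posterior over $\mathbf{z}$:

$$
p(\mathbf{z}|\mathbf{y}_1, \ldots, \mathbf{y}_N) = \mathcal{N}(\mathbf{z}|\hat{\boldsymbol{\mu}}, \hat{\boldsymbol{\Sigma}}) \tag{3.66}
$$

$$
\hat{\boldsymbol{\Sigma}}^{-1} = \boldsymbol{\Sigma}_z^{-1} + N\boldsymbol{\Sigma}_y^{-1} \tag{3.67}
$$

$$
\hat{\boldsymbol{\mu}} = \hat{\boldsymbol{\Sigma}}\left(\boldsymbol{\Sigma}_y^{-1}(N\bar{\mathbf{y}}) + \boldsymbol{\Sigma}_z^{-1}\boldsymbol{\mu}_z\right) \tag{3.68}
$$

where $\hat{\boldsymbol{\mu}}$ and $\hat{\boldsymbol{\Sigma}}$ are the parameters of the posterior.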

Figure 3.9 gives a 2d example. We can think of $\mathbf{z}$ as representing the true, but unknown, location
of an object in 2d space, such as a missile or airplane, and the $\mathbf{y}_n$ as being noisy observations, such
as radar “blips”. As we receive more blips, we are better able to localize the source. (In the sequel to
this book, [Mur23], we discuss the Kalman filter algorithm, which extends this idea to a temporal
sequence of observations.)

The posterior uncertainty about each component of the location vector $\mathbf{z}$ depends on how reliable the
sensor is in each of these dimensions. In the above example, the measurement noise in dimension 1 is
higher than in dimension 2, so we have more posterior uncertainty about $z_1$ (horizontal axis) than
about $z_2$ (vertical axis).
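The posterior of Equations (3.67) and (3.68) can be sketched as follows. The true location, noise covariance, prior and number of observations are taken from the caption of Figure 3.9; the random seed is arbitrary, and this is an illustrative NumPy sketch rather than the notebook that generated the figure.

```python
import numpy as np

rng = np.random.default_rng(0)

# Settings from the caption of Figure 3.9.
z_true = np.array([0.5, 0.5])
Sigma_y = 0.1 * np.array([[2.0, 1.0], [1.0, 1.0]])
mu_z, Sigma_z = np.zeros(2), 0.1 * np.eye(2)
N = 10

# Simulate N noisy measurements of z and average them.
ys = rng.multivariate_normal(z_true, Sigma_y, size=N)
ybar = ys.mean(axis=0)

prec_z = np.linalg.inv(Sigma_z)
prec_y = np.linalg.inv(Sigma_y)
post_cov = np.linalg.inv(prec_z + N * prec_y)                    # Equation (3.67)
post_mean = post_cov @ (prec_y @ (N * ybar) + prec_z @ mu_z)     # Equation (3.68)
```

The posterior covariance does not depend on the observed values, only on $N$ and the two covariance matrices; its first diagonal entry is larger than its second, reflecting the noisier first dimension.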


2. This derivation is due to Joaquin Rapela. See https://github.com/probml/pml-book/issues/512.

Figure 3.9: Illustration of Bayesian inference for a 2d Gaussian random vector z. (a) The data is generated
from $\mathbf{y}_n \sim \mathcal{N}(\mathbf{z}, \boldsymbol{\Sigma}_y)$, where $\mathbf{z} = [0.5, 0.5]^\top$ and $\boldsymbol{\Sigma}_y = 0.1[2, 1; 1, 1]$. We assume the sensor noise covariance
$\boldsymbol{\Sigma}_y$ is known but $\mathbf{z}$ is unknown. The black cross represents $\mathbf{z}$. (b) The prior is $p(\mathbf{z}) = \mathcal{N}(\mathbf{z}|\mathbf{0}, 0.1\mathbf{I}_2)$. (c) We
show the posterior after 10 data points have been observed. Generated by gauss_infer_2d.ipynb.


3.3.5 Example: sensor fusion 


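In this section, we extend Section 3.3.4 to the case where we have multiple measurements, coming
from different sensors, each with different reliabilities. That is, the model has the form

$$
p(\mathbf{z}, \mathbf{y}) = p(\mathbf{z})\prod_{m=1}^{M}\prod_{n=1}^{N_m}\mathcal{N}(\mathbf{y}_{n,m}|\mathbf{z}, \boldsymbol{\Sigma}_m) \tag{3.69}
$$

where $M$ is the number of sensors (measurement devices), and $N_m$ is the number of observations
from sensor $m$, and $\mathbf{y} = \mathbf{y}_{1:N,1:M} \in \mathbb{R}^K$. Our goal is to combine the evidence together, to compute
$p(\mathbf{z}|\mathbf{y})$. This is known as sensor fusion.

We now give a simple example, where there are just two sensors, so $\mathbf{y}_1 \sim \mathcal{N}(\mathbf{z}, \boldsymbol{\Sigma}_1)$ and
$\mathbf{y}_2 \sim \mathcal{N}(\mathbf{z}, \boldsymbol{\Sigma}_2)$. Pictorially, we can represent this example as $\mathbf{y}_1 \leftarrow \mathbf{z} \rightarrow \mathbf{y}_2$. We can combine $\mathbf{y}_1$ and $\mathbf{y}_2$
into a single vector $\mathbf{y}$, so the model can be represented as $\mathbf{z} \rightarrow [\mathbf{y}_1, \mathbf{y}_2]$, where $p(\mathbf{y}|\mathbf{z}) = \mathcal{N}(\mathbf{y}|\mathbf{W}\mathbf{z}, \boldsymbol{\Sigma}_y)$,
where $\mathbf{W} = [\mathbf{I}; \mathbf{I}]$ and $\boldsymbol{\Sigma}_y = [\boldsymbol{\Sigma}_1, \mathbf{0}; \mathbf{0}, \boldsymbol{\Sigma}_2]$ are block-structured matrices. We can then apply Bayes’
rule for Gaussians to compute $p(\mathbf{z}|\mathbf{y})$.

Figure 3.10(a) gives a 2d example, where we set $\boldsymbol{\Sigma}_1 = \boldsymbol{\Sigma}_2 = 0.01\mathbf{I}_2$, so both sensors are equally
reliable. In this case, the posterior mean is halfway between the two observations, $\mathbf{y}_1$ and $\mathbf{y}_2$. In
Figure 3.10(b), we set $\boldsymbol{\Sigma}_1 = 0.05\mathbf{I}_2$ and $\boldsymbol{\Sigma}_2 = 0.01\mathbf{I}_2$, so sensor 2 is more reliable than sensor 1. In
this case, the posterior mean is closer to $\mathbf{y}_2$. In Figure 3.10(c), we set

$$
\boldsymbol{\Sigma}_1 = 0.01\begin{pmatrix}10 & 1 \\ 1 & 1\end{pmatrix}, \quad \boldsymbol{\Sigma}_2 = 0.01\begin{pmatrix}1 & 1 \\ 1 & 10\end{pmatrix} \tag{3.70}
$$

so sensor 1 is more reliable in the second component (vertical direction), and sensor 2 is more reliable
in the first component (horizontal direction). In this case, the posterior mean uses $\mathbf{y}_1$’s vertical
component and $\mathbf{y}_2$’s horizontal component.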
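The two-sensor case can be sketched by specializing Bayes rule for Gaussians to $\mathbf{W} = [\mathbf{I}; \mathbf{I}]$ and block-diagonal noise, which reduces to adding precisions. The observations and sensor covariances below follow Figure 3.10 and Equation (3.70); the very broad prior $\mathcal{N}(\mathbf{0}, 10^6\mathbf{I})$ and the function name `fuse` are assumptions made for illustration.

```python
import numpy as np


def fuse(y1, Sigma1, y2, Sigma2, prior_mean, prior_cov):
    """Posterior mean and covariance of z given one reading from each sensor."""
    P0, P1, P2 = (np.linalg.inv(S) for S in (prior_cov, Sigma1, Sigma2))
    post_cov = np.linalg.inv(P0 + P1 + P2)            # precisions add
    post_mean = post_cov @ (P0 @ prior_mean + P1 @ y1 + P2 @ y2)
    return post_mean, post_cov


y1, y2 = np.array([0.0, -1.0]), np.array([1.0, 0.0])
prior_mean, prior_cov = np.zeros(2), 1e6 * np.eye(2)   # assumed vague prior

# (a) Equally reliable sensors.
mean_a, _ = fuse(y1, 0.01 * np.eye(2), y2, 0.01 * np.eye(2), prior_mean, prior_cov)
# (b) Sensor 2 more reliable than sensor 1.
mean_b, _ = fuse(y1, 0.05 * np.eye(2), y2, 0.01 * np.eye(2), prior_mean, prior_cov)
# (c) Each sensor reliable in a different direction, Equation (3.70).
Sigma1_c = 0.01 * np.array([[10.0, 1.0], [1.0, 1.0]])
Sigma2_c = 0.01 * np.array([[1.0, 1.0], [1.0, 10.0]])
mean_c, _ = fuse(y1, Sigma1_c, y2, Sigma2_c, prior_mean, prior_cov)
```

In case (c) the fused estimate lies outside the line segment joining the two observations, moving towards $\mathbf{y}_2$ horizontally and towards $\mathbf{y}_1$ vertically.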
Figure 3.10: We observe $\mathbf{y}_1 = (0, -1)$ (red cross) and $\mathbf{y}_2 = (1, 0)$ (green cross) and estimate $\mathbb{E}[\mathbf{z}|\mathbf{y}_1, \mathbf{y}_2]$
(black cross). (a) Equally reliable sensors, so the posterior mean estimate is in between the two circles. (b)
Sensor 2 is more reliable, so the estimate shifts more towards the green circle. (c) Sensor 1 is more reliable
in the vertical direction, Sensor 2 is more reliable in the horizontal direction. The estimate is an appropriate
combination of the two measurements. Generated by sensor_fusion_2d.ipynb.


3.4 The exponential family * 


In this section, we define the exponential family, which includes many common probability 
distributions. The exponential family plays a crucial role in statistics and machine learning. In this 
book, we mainly use it in the context of generalized linear models, which we discuss in Chapter 12. 
We will see more applications of the exponential family in the sequel to this book, [Mur23]. 


3.4.1 Definition 


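Consider a family of probability distributions parameterized by $\boldsymbol{\eta} \in \mathbb{R}^K$ with fixed support over
$\mathcal{Y}^D \subseteq \mathbb{R}^D$. We say that the distribution $p(\mathbf{y}|\boldsymbol{\eta})$ is in the exponential family if its density can be
written in the following way:

$$
p(\mathbf{y}|\boldsymbol{\eta}) \triangleq \frac{1}{Z(\boldsymbol{\eta})}h(\mathbf{y})\exp[\boldsymbol{\eta}^\top\mathcal{T}(\mathbf{y})] = h(\mathbf{y})\exp[\boldsymbol{\eta}^\top\mathcal{T}(\mathbf{y}) - A(\boldsymbol{\eta})] \tag{3.71}
$$

where $h(\mathbf{y})$ is a scaling constant (also known as the base measure, often 1), $\mathcal{T}(\mathbf{y}) \in \mathbb{R}^K$ are
the sufficient statistics, $\boldsymbol{\eta}$ are the natural parameters or canonical parameters, $Z(\boldsymbol{\eta})$ is a
normalization constant known as the partition function, and $A(\boldsymbol{\eta}) = \log Z(\boldsymbol{\eta})$ is the log partition
function. One can show that $A$ is a convex function over the convex set $\Omega \triangleq \{\boldsymbol{\eta} \in \mathbb{R}^K : A(\boldsymbol{\eta}) < \infty\}$.

It is convenient if the natural parameters are independent of each other. Formally, we say that an
exponential family is minimal if there is no $\boldsymbol{\eta} \in \mathbb{R}^K \setminus \{\mathbf{0}\}$ such that $\boldsymbol{\eta}^\top\mathcal{T}(\mathbf{y}) = 0$. This last condition
can be violated in the case of multinomial distributions, because of the sum to one constraint on
the parameters; however, it is easy to reparameterize the distribution using $K - 1$ independent
parameters, as we show below.

Equation (3.71) can be generalized by defining $\boldsymbol{\eta} = f(\boldsymbol{\phi})$, where $\boldsymbol{\phi}$ is some other, possibly smaller,
set of parameters. In this case, the distribution has the form

$$
p(\mathbf{y}|\boldsymbol{\phi}) = h(\mathbf{y})\exp[f(\boldsymbol{\phi})^\top\mathcal{T}(\mathbf{y}) - A(f(\boldsymbol{\phi}))] \tag{3.72}
$$

If the mapping from $\boldsymbol{\phi}$ to $\boldsymbol{\eta}$ is nonlinear, we call this a curved exponential family. If $\boldsymbol{\eta} = f(\boldsymbol{\phi}) = \boldsymbol{\phi}$,
the model is said to be in canonical form. If, in addition, $\mathcal{T}(\mathbf{y}) = \mathbf{y}$, we say this is a natural
exponential family or NEF. In this case, it can be written as

$$
p(\mathbf{y}|\boldsymbol{\eta}) = h(\mathbf{y})\exp[\boldsymbol{\eta}^\top\mathbf{y} - A(\boldsymbol{\eta})] \tag{3.73}
$$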


3.4.2 Example 


As a simple example, let us consider the Bernoulli distribution. We can write this in exponential 
family form as follows: 


Ber(y|u) = p” (1 — u)™” (3.74) 
= exp[y log(u) + (1 — y) log(1 — p)] (3.75) 
= exp[7 (y)"n] (3.76) 


where $\mathcal{T}(y) = [\mathbb{I}(y = 1), \mathbb{I}(y = 0)]$, $\eta = [\log(\mu), \log(1 - \mu)]$, and $\mu$ is the mean parameter. However,
this is an over-complete representation since there is a linear dependence between the features.
We can see this as follows:

$$\mathbf{1}^T \mathcal{T}(y) = \mathbb{I}(y = 0) + \mathbb{I}(y = 1) = 1 \tag{3.77}$$

If the representation is overcomplete, $\eta$ is not uniquely identifiable. It is common to use a minimal
representation, which means there is a unique $\eta$ associated with the distribution. In this case, we
can just define

$$\mathrm{Ber}(y|\mu) = \exp\left[ y \log\left(\frac{\mu}{1 - \mu}\right) + \log(1 - \mu) \right] \tag{3.78}$$


We can put this into exponential family form by defining 


$$\eta = \log\left(\frac{\mu}{1 - \mu}\right) \tag{3.79}$$
$$\mathcal{T}(y) = y \tag{3.80}$$
$$A(\eta) = -\log(1 - \mu) = \log(1 + e^{\eta}) \tag{3.81}$$
$$h(y) = 1 \tag{3.82}$$


We can recover the mean parameter $\mu$ from the canonical parameter $\eta$ using

$$\mu = \sigma(\eta) = \frac{1}{1 + e^{-\eta}} \tag{3.83}$$

which we recognize as the logistic (sigmoid) function.
See the sequel to this book, [Mur23], for more examples. 
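The Bernoulli example above can be checked numerically. The following is a minimal sketch using only the Python standard library; the function names (`sigmoid`, `log_partition`, `bernoulli_expfam`) and the value $\mu = 0.3$ are illustrative choices, not part of the original text.

```python
import math

def sigmoid(eta):
    """Logistic function: maps the natural parameter eta to the mean mu."""
    return 1.0 / (1.0 + math.exp(-eta))

def log_partition(eta):
    """A(eta) = log(1 + e^eta) for the Bernoulli, as in Equation (3.81)."""
    return math.log1p(math.exp(eta))

def bernoulli_expfam(y, eta):
    """Minimal exponential family form with h(y) = 1 and T(y) = y."""
    return math.exp(eta * y - log_partition(eta))

mu = 0.3  # illustrative mean parameter
eta = math.log(mu / (1.0 - mu))  # natural parameter (log odds), Equation (3.79)

# The exponential family form should reproduce the usual pmf mu^y (1 - mu)^(1 - y).
pmf_direct = [mu**y * (1.0 - mu) ** (1 - y) for y in (0, 1)]
pmf_expfam = [bernoulli_expfam(y, eta) for y in (0, 1)]

# Applying the sigmoid to eta should recover the mean parameter.
mu_recovered = sigmoid(eta)
```

Both lists agree up to floating point error, and `mu_recovered` matches `mu`, which is the round trip between the mean and canonical parameterizations.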
3.4.3 Log partition function is cumulant generating function


The first and second cumulants of a distribution are its mean $\mathbb{E}[Y]$ and variance $\mathbb{V}[Y]$, whereas the
first and second moments are $\mathbb{E}[Y]$ and $\mathbb{E}[Y^2]$. We can also compute higher order cumulants (and
moments). An important property of the exponential family is that derivatives of the log partition
function can be used to generate all the cumulants of the sufficient statistics. In particular, the
first and second cumulants are given by

$$\nabla A(\eta) = \mathbb{E}[\mathcal{T}(y)] \tag{3.84}$$
$$\nabla^2 A(\eta) = \mathrm{Cov}[\mathcal{T}(y)] \tag{3.85}$$

From the above result, we see that the Hessian is positive definite, and hence $A(\eta)$ is convex in $\eta$.
Since the log likelihood has the form $\log p(y|\eta) = \eta^T \mathcal{T}(y) - A(\eta) + \text{const}$, we see that this is concave,
and hence the MLE has a unique global maximum.
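As a sanity check on the cumulant property, the sketch below differentiates the Bernoulli log partition function $A(\eta) = \log(1 + e^{\eta})$ numerically and compares the results with the known mean $\mu$ and variance $\mu(1 - \mu)$. The finite-difference helpers and step sizes are assumptions made for illustration; an autodiff library would serve equally well.

```python
import math

def log_partition(eta):
    """Bernoulli log partition function A(eta) = log(1 + e^eta)."""
    return math.log1p(math.exp(eta))

def first_derivative(f, x, h=1e-5):
    """Central finite-difference estimate of f'(x)."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

def second_derivative(f, x, h=1e-4):
    """Central finite-difference estimate of f''(x)."""
    return (f(x + h) - 2.0 * f(x) + f(x - h)) / (h * h)

mu = 0.3  # illustrative mean parameter
eta = math.log(mu / (1.0 - mu))

mean_from_A = first_derivative(log_partition, eta)   # should approximate E[y] = mu
var_from_A = second_derivative(log_partition, eta)   # should approximate V[y] = mu (1 - mu)
```

The second derivative is positive, which is the scalar version of the statement that the Hessian of $A$ is a covariance matrix.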


3.4.4 Maximum entropy derivation of the exponential family 


Suppose we want to find a distribution $p(x)$ to describe some data, where all we know are the
expected values ($F_k$) of certain features or functions $f_k(x)$:

$$\int dx \, p(x) f_k(x) = F_k \tag{3.86}$$

For example, $f_1$ might compute $x$, $f_2$ might compute $x^2$, making $F_1$ the empirical mean and $F_2$ the
empirical second moment. Our prior belief in the distribution is $q(x)$.

To formalize what we mean by “least number of assumptions”, we will search for the distribution 
that is as close as possible to our prior q(x), in the sense of KL divergence (Section 6.2), while 
satisfying our constraints: 


$$p = \operatorname*{argmin}_{p} D_{\mathbb{KL}}(p \,\|\, q) \quad \text{subject to constraints} \tag{3.87}$$

If we use a uniform prior, $q(x) \propto 1$, minimizing the KL divergence is equivalent to maximizing the
entropy (Section 6.1):

$$p = \operatorname*{argmax}_{p} \mathbb{H}(p) \quad \text{subject to constraints} \tag{3.88}$$

The result is called a maximum entropy model.


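To minimize the KL subject to the constraints in Equation (3.86), and the constraints that $p(x) \geq 0$
and $\sum_x p(x) = 1$, we will use Lagrange multipliers (see Section 8.5.1). The Lagrangian is given by

$$J(p, \lambda) = -\sum_x p(x) \log \frac{p(x)}{q(x)} + \lambda_0 \left(1 - \sum_x p(x)\right) + \sum_k \lambda_k \left(F_k - \sum_x p(x) f_k(x)\right) \tag{3.89}$$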


We can use the calculus of variations to take derivatives wrt the function $p$, but we will adopt a
simpler approach and treat $p$ as a fixed length vector (since we are assuming that $x$ is discrete).
Then we have

$$\frac{\partial J}{\partial p_c} = -1 - \log \frac{p(x = c)}{q(x = c)} - \lambda_0 - \sum_k \lambda_k f_k(x = c) \tag{3.90}$$
Setting $\frac{\partial J}{\partial p_c} = 0$ for each $c$ yields

$$p(x) = \frac{q(x)}{Z} \exp\left(-\sum_k \lambda_k f_k(x)\right) \tag{3.91}$$

where we have defined $Z \triangleq e^{1 + \lambda_0}$. Using the sum-to-one constraint, we have

$$1 = \sum_x p(x) = \frac{1}{Z} \sum_x q(x) \exp\left(-\sum_k \lambda_k f_k(x)\right) \tag{3.92}$$

Hence the normalization constant is given by

$$Z = \sum_x q(x) \exp\left(-\sum_k \lambda_k f_k(x)\right) \tag{3.93}$$


This has exactly the form of the exponential family, where $f(x)$ is the vector of sufficient statistics,
$-\lambda$ are the natural parameters, and $q(x)$ is our base measure.

For example, if the features are $f_1(x) = x$ and $f_2(x) = x^2$, and we want to match the first and
second moments, we get the Gaussian distribution.
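To make the derivation concrete, here is a small sketch of a discrete maximum entropy problem: a six-sided die with a uniform base measure $q$ and a single feature $f(x) = x$ whose expected value is constrained to 4.5. The setup, the target value, and the bisection solver are all illustrative assumptions; the only thing taken from the text is the solution form $p(x) \propto q(x) \exp(-\lambda f(x))$.

```python
import math

def maxent_pmf(lam, support):
    """p(x) = exp(-lam * x) / Z for a uniform base measure q, as in Equation (3.91)."""
    weights = [math.exp(-lam * x) for x in support]
    Z = sum(weights)  # normalization constant, as in Equation (3.93)
    return [w / Z for w in weights]

def mean_of(pmf, support):
    """Expected value of the feature f(x) = x under the pmf."""
    return sum(p * x for p, x in zip(pmf, support))

def solve_multiplier(target_mean, support, lo=-10.0, hi=10.0, iters=200):
    """Find lam by bisection; the mean is a decreasing function of lam."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if mean_of(maxent_pmf(mid, support), support) > target_mean:
            lo = mid  # mean too large, so lam must increase
        else:
            hi = mid
    return 0.5 * (lo + hi)

support = [1, 2, 3, 4, 5, 6]
target_mean = 4.5  # hypothetical constraint F_1
lam = solve_multiplier(target_mean, support)
pmf = maxent_pmf(lam, support)
```

Because the target mean exceeds the uniform mean of 3.5, the multiplier comes out negative and the resulting pmf tilts smoothly toward the larger faces.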


3.5 Mixture models 


One way to create more complex probability models is to take a convex combination of simple 
distributions. This is called a mixture model. This has the form 


$$p(y|\theta) = \sum_{k=1}^K \pi_k p_k(y) \tag{3.94}$$

where $p_k$ is the $k$'th mixture component, and $\pi_k$ are the mixture weights which satisfy $0 \leq \pi_k \leq 1$
and $\sum_{k=1}^K \pi_k = 1$.

We can re-express this model as a hierarchical model, in which we introduce the discrete latent
variable $z \in \{1, \ldots, K\}$, which specifies which distribution to use for generating the output $y$. The
prior on this latent variable is $p(z = k|\theta) = \pi_k$, and the conditional is $p(y|z = k, \theta) = p_k(y) = p(y|\theta_k)$.
That is, we define the following joint model:

$$p(z|\theta) = \mathrm{Cat}(z|\pi) \tag{3.95}$$
$$p(y|z = k, \theta) = p(y|\theta_k) \tag{3.96}$$

where $\theta = (\pi_1, \ldots, \pi_K, \theta_1, \ldots, \theta_K)$ are all the model parameters. The “generative story” for the data
is that we first sample a specific component $z$, and then we generate the observations $y$ using the
parameters chosen according to the value of $z$. By marginalizing out $z$, we recover Equation (3.94):

$$p(y|\theta) = \sum_{k=1}^K p(z = k|\theta) p(y|z = k, \theta) = \sum_{k=1}^K \pi_k p(y|\theta_k) \tag{3.97}$$


We can create different kinds of mixture model by varying the base distribution $p_k$, as we illustrate
below.
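The generative story can be sketched directly as ancestral sampling. The example below assumes one-dimensional Gaussian components with made-up weights, means, and standard deviations; it first draws the latent component $z$ from the categorical prior, then draws $y$ from the chosen component, and also evaluates the marginal density obtained by summing out $z$.

```python
import math
import random

def sample_mixture(weights, means, stds, n, seed=0):
    """Ancestral sampling: draw z ~ Cat(pi), then y ~ N(mu_z, sigma_z^2)."""
    rng = random.Random(seed)
    components = list(range(len(weights)))
    samples = []
    for _ in range(n):
        z = rng.choices(components, weights=weights)[0]
        y = rng.gauss(means[z], stds[z])
        samples.append((z, y))
    return samples

def normal_pdf(y, mean, std):
    """Density of a univariate Gaussian."""
    return math.exp(-0.5 * ((y - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def mixture_pdf(y, weights, means, stds):
    """Marginal density p(y) = sum_k pi_k N(y | mu_k, sigma_k^2)."""
    return sum(w * normal_pdf(y, m, s) for w, m, s in zip(weights, means, stds))

# Illustrative parameters (not from the text).
weights = [0.5, 0.3, 0.2]
means = [-2.0, 0.0, 3.0]
stds = [0.5, 1.0, 0.8]

samples = sample_mixture(weights, means, stds, n=20000)
# The empirical frequency of each component should be close to its mixture weight.
freqs = [sum(1 for z, _ in samples if z == k) / len(samples) for k in range(3)]
# A Riemann sum of the marginal density over a wide grid should be close to 1.
total_mass = sum(mixture_pdf(-15.0 + 0.01 * i, weights, means, stds) for i in range(3001)) * 0.01
```

In practice the labels $z$ are discarded after sampling, since only $y$ is observed; they are kept here to show that the component frequencies track the mixture weights.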
Figure 3.11: A mixture of 3 Gaussians in 2d. (a) We show the contours of constant probability for each
component in the mixture. (b) A surface plot of the overall density. Adapted from Figure 2.23 of [Bis06].
Generated by gmm_plot_demo.ipynb.

Figure 3.12: (a) Some data in 2d. (b) A possible clustering using K = 5 clusters computed using a GMM.
Generated by gmm_2d.ipynb.


3.5.1 Gaussian mixture models 


A Gaussian mixture model or GMM, also called a mixture of Gaussians (MoG), is defined 
as follows: 


$$p(y|\theta) = \sum_{k=1}^K \pi_k \mathcal{N}(y|\mu_k, \Sigma_k) \tag{3.98}$$


In Figure 3.11 we show the density defined by a mixture of 3 Gaussians in 2d. Each mixture 
component is represented by a different set of elliptical contours. If we let the number of mixture 
components grow sufficiently large, a GMM can approximate any smooth distribution over $\mathbb{R}^D$.

GMMs are often used for unsupervised clustering of real-valued data samples $y_n \in \mathbb{R}^D$. This
works in two stages. First we fit the model, e.g., by computing the MLE $\hat{\theta} = \operatorname{argmax}_{\theta} \log p(\mathcal{D}|\theta)$, where
$\mathcal{D} = \{y_n : n = 1:N\}$. (We discuss how to compute this MLE in Section 8.7.3.) Then we associate
each data point $y_n$ with a discrete latent or hidden variable $z_n \in \{1, \ldots, K\}$ which specifies the
identity of the mixture component or cluster which was used to generate $y_n$. These latent identities
are unknown, but we can compute a posterior over them using Bayes rule:

$$r_{nk} \triangleq p(z_n = k|y_n, \theta) = \frac{p(z_n = k|\theta) p(y_n|z_n = k, \theta)}{\sum_{k'=1}^K p(z_n = k'|\theta) p(y_n|z_n = k', \theta)} \tag{3.99}$$

The quantity $r_{nk}$ is called the responsibility of cluster $k$ for data point $n$. Given the responsibilities,
we can compute the most probable cluster assignment as follows:

$$\hat{z}_n = \operatorname*{argmax}_{k} r_{nk} = \operatorname*{argmax}_{k} \left[\log p(y_n|z_n = k, \theta) + \log p(z_n = k|\theta)\right] \tag{3.100}$$

This is known as hard clustering. (If we use the responsibilities to fractionally assign each data
point to different clusters, it is called soft clustering.) See Figure 3.12 for an example.

Figure 3.13: We fit a mixture of 20 Bernoullis to the binarized MNIST digit data. We visualize the estimated
cluster means $\hat{\mu}_k$. The numbers on top of each image represent the estimated mixing weights $\hat{\pi}_k$. No labels
were used when training the model. Generated by mix_bernoulli_em_mnist.ipynb.
(Mixing weights shown above the panels, row by row: 0.05 0.04 0.06 0.06 0.02; 0.03 0.06 0.01 0.04 0.08;
0.07 0.06 0.08 0.06 0.07; 0.04 0.02 0.03 0.06 0.03.)

If we have a uniform prior over $z_n$, and we use spherical Gaussians with $\Sigma_k = I$, the hard clustering
problem reduces to

$$z_n = \operatorname*{argmin}_{k} \|y_n - \hat{\mu}_k\|_2^2 \tag{3.101}$$
In other words, we assign each data point to its closest centroid, as measured by Euclidean distance. 
This is the basis of the K-means clustering algorithm, which we discuss in Section 21.3. 
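The link between responsibilities and nearest-centroid assignment can be illustrated with a short sketch. It assumes identity-covariance Gaussian components, uniform mixture weights, and three hypothetical centroids in 2d; the helper names are invented for this example.

```python
import math

def log_gauss_identity(y, mu):
    """log N(y | mu, I) for a D-dimensional point with identity covariance."""
    d = len(y)
    sq_dist = sum((yi - mi) ** 2 for yi, mi in zip(y, mu))
    return -0.5 * sq_dist - 0.5 * d * math.log(2.0 * math.pi)

def responsibilities(y, weights, means):
    """r_k = p(z = k | y, theta) via Bayes rule, as in Equation (3.99)."""
    log_joint = [math.log(w) + log_gauss_identity(y, m) for w, m in zip(weights, means)]
    shift = max(log_joint)  # subtract the max before exponentiating, for stability
    unnorm = [math.exp(v - shift) for v in log_joint]
    total = sum(unnorm)
    return [u / total for u in unnorm]

def hard_assign(y, weights, means):
    """Most probable cluster under the posterior (hard clustering)."""
    r = responsibilities(y, weights, means)
    return max(range(len(r)), key=r.__getitem__)

def nearest_centroid(y, means):
    """Index of the closest centroid in squared Euclidean distance."""
    return min(range(len(means)),
               key=lambda k: sum((yi - mi) ** 2 for yi, mi in zip(y, means[k])))

# Hypothetical centroids and query points.
means = [(0.0, 0.0), (3.0, 3.0), (-3.0, 3.0)]
weights = [1.0 / 3.0] * 3  # uniform prior over z
points = [(0.1, 0.2), (2.9, 3.2), (-2.0, 4.0)]

hard = [hard_assign(y, weights, means) for y in points]
nearest = [nearest_centroid(y, means) for y in points]
soft = [responsibilities(y, weights, means) for y in points]
```

With non-uniform weights or non-spherical covariances the two assignments can differ, which is why the reduction to Euclidean distance needs both assumptions.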
3.5.2 Bernoulli mixture models 


If the data is binary valued, we can use a Bernoulli mixture model or BMM (also called a 
mixture of Bernoullis), where each mixture component has the following form: 


$$p(y|z = k, \theta) = \prod_{d=1}^D \mathrm{Ber}(y_d|\mu_{dk}) = \prod_{d=1}^D \mu_{dk}^{y_d} (1 - \mu_{dk})^{1 - y_d} \tag{3.102}$$
P(C=F) | P(C=T)
0.5    | 0.5

C | P(S=F) | P(S=T)
F | 0.5    | 0.5
T | 0.9    | 0.1

C | P(R=F) | P(R=T)
F | 0.8    | 0.2
T | 0.2    | 0.8

S R | P(W=F) | P(W=T)
F F | 1.0    | 0.0
T F | 0.1    | 0.9
F T | 0.1    | 0.9
T T | 0.01   | 0.99

Figure 3.14: Water sprinkler PGM with corresponding binary CPTs. T and F stand for true and false.


Here $\mu_{dk}$ is the probability that bit $d$ turns on in cluster $k$.

As an example, we fit a BMM using K = 20 components to the MNIST dataset (Section 3.5.2). (We
use the EM algorithm to do this fitting, which is similar to EM for GMMs discussed in Section 8.7.3;
however we can also use SGD to fit the model, which is more efficient for large datasets.³) The
resulting parameters for each mixture component (i.e., $\mu_k$ and $\pi_k$) are shown in Figure 3.13. We see
that the model has “discovered” a representation of each type of digit. (Some digits are represented 
multiple times, since the model does not know the “true” number of classes. See Section 21.3.7 for 
more information on how to choose the number K of mixture components.) 


3.6 Probabilistic graphical models * 


I basically know of two principles for treating complicated systems in simple ways: the first is 
the principle of modularity and the second is the principle of abstraction. I am an apologist 
for computational probability in machine learning because I believe that probability theory 
implements these two principles in deep and intriguing ways — namely through factorization 
and through averaging. Exploiting these two mechanisms as fully as possible seems to me to 
be the way forward in machine learning. — Michael Jordan, 1997 (quoted in [Fre98]). 


We have now introduced a few simple probabilistic building blocks. In Section 3.3, we showed 
one way to combine some Gaussian building blocks to build a high dimensional distribution p(y) 
from simpler parts, namely the marginal $p(y_1)$ and the conditional $p(y_2|y_1)$. This idea can be
extended to define joint distributions over sets of many random variables. The key assumption we 
will make is that some variables are conditionally independent of others. We will represent our 
CI assumptions using graphs, as we briefly explain below. (See the sequel to this book, [Mur23], for 
more information.) 


3. For the SGD code, see mix_bernoulli_sgd_mnist.ipynb. 


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 


100 Chapter 3. Probability: Multivariate Models 


3.6.1 Representation 


A probabilistic graphical model or PGM is a joint probability distribution that uses a graph 
structure to encode conditional independence assumptions. When the graph is a directed acyclic 
graph or DAG, the model is sometimes called a Bayesian network, although there is nothing 
inherently Bayesian about such models. 

The basic idea in PGMs is that each node in the graph represents a random variable, and each 
edge represents a direct dependency. More precisely, each lack of edge represents a conditional 
independency. In the DAG case, we can number the nodes in topological order (parents before 
children), and then we connect them such that each node is conditionally independent of all its 
predecessors given its parents: 


$$Y_i \perp Y_{\mathrm{pred}(i) \setminus \mathrm{pa}(i)} \mid Y_{\mathrm{pa}(i)} \tag{3.103}$$


where pa(i) are the parents of node i, and pred(i) are the predecessors of node i in the ordering. (This 
is called the ordered Markov property.) Consequently, we can represent the joint distribution as 
follows: 


$$p(Y_{1:N_G}) = \prod_{i=1}^{N_G} p(Y_i|Y_{\mathrm{pa}(i)}) \tag{3.104}$$

where $N_G$ is the number of nodes in the graph.


3.6.1.1 Example: water sprinkler network 


Suppose we want to model the dependencies between 4 random variables: C (whether it is cloudy 
season or not), R (whether it is raining or not), S (whether the water sprinkler is on or not), and W 
(whether the grass is wet or not). We know that the cloudy season makes rain more likely, so we add
a $C \to R$ arc. We know that the cloudy season makes turning on a water sprinkler less likely, so we
add a $C \to S$ arc. Finally, we know that either rain or sprinklers can cause the grass to get wet, so
we add $S \to W$ and $R \to W$ edges.

Formally, this defines the following joint distribution: 


$$p(C, S, R, W) = p(C) \, p(S|C) \, p(R|C, \cancel{S}) \, p(W|S, R, \cancel{C}) \tag{3.105}$$


where we strike through terms that are not needed due to the conditional independence properties of 
the model. 

Each term $p(Y_i|Y_{\mathrm{pa}(i)})$ is called the conditional probability distribution or CPD for node
$i$. This can be any kind of distribution we like. In Figure 3.14, we assume each CPD is a conditional
categorical distribution, which can be represented as a conditional probability table or CPT.
We can represent the $i$'th CPT as follows:

$$\theta_{ijk} \triangleq p(Y_i = k|Y_{\mathrm{pa}(i)} = j) \tag{3.106}$$

This satisfies the properties $0 \leq \theta_{ijk} \leq 1$ and $\sum_{k=1}^{K_i} \theta_{ijk} = 1$ for each row $j$. Here $i$ indexes nodes,
$i \in [N_G]$; $k$ indexes node states, $k \in [K_i]$, where $K_i$ is the number of states for node $i$; and $j$ indexes
joint parent states, $j \in [J_i]$, where $J_i = \prod_{p \in \mathrm{pa}(i)} K_p$. For example, the wet grass node has 2 binary
parents, so there are 4 parent states.
Figure 3.15: Illustration of first and second order autoregressive (Markov) models.


3.6.1.2 Example: Markov chain 


Suppose we want to create a joint probability distribution over variable-length sequences, $p(y_{1:T})$. If
each variable $y_t$ represents a word from a vocabulary with $K$ possible values, so $y_t \in \{1, \ldots, K\}$, the
resulting model represents a distribution over possible sentences of length $T$; this is often called a
language model.

By the chain rule of probability, we can represent any joint distribution over T variables as follows: 


$$p(y_{1:T}) = p(y_1) p(y_2|y_1) p(y_3|y_2, y_1) p(y_4|y_3, y_2, y_1) \cdots = \prod_{t=1}^T p(y_t|y_{1:t-1}) \tag{3.107}$$


Unfortunately, the number of parameters needed to represent each conditional distribution $p(y_t|y_{1:t-1})$
grows exponentially with $t$. However, suppose we make the conditional independence assumption that
the future, $y_{t+1:T}$, is independent of the past, $y_{1:t-1}$, given the present, $y_t$. This is called the first
order Markov condition, and is represented by the PGM in Figure 3.15(a). With this assumption,
we can write the joint distribution as follows:

$$p(y_{1:T}) = p(y_1) p(y_2|y_1) p(y_3|y_2) p(y_4|y_3) \cdots = p(y_1) \prod_{t=2}^T p(y_t|y_{t-1}) \tag{3.108}$$


This is called a Markov chain, Markov model or autoregressive model of order 1. 

The function $p(y_t|y_{t-1})$ is called the transition function, transition kernel or Markov kernel.
This is just a conditional distribution over the states at time $t$ given the state at time $t - 1$, and hence
it satisfies the conditions $p(y_t|y_{t-1}) \geq 0$ and $\sum_{k=1}^K p(y_t = k|y_{t-1} = j) = 1$. We can represent this
CPT as a stochastic matrix, $A_{jk} = p(y_t = k|y_{t-1} = j)$, where each row sums to 1. This is known
as the state transition matrix. We assume this matrix is the same for all time steps, so the model 
is said to be homogeneous, stationary, or time-invariant. This is an example of parameter 
tying, since the same parameter is shared by multiple variables. This assumption allows us to model 
an arbitrary number of variables using a fixed number of parameters. 
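A homogeneous first-order Markov chain is fully specified by an initial distribution and one transition matrix, which the following sketch makes explicit. The two-state matrix `A` and initial distribution `pi0` are made-up values; the code samples a sequence and evaluates its log probability using the factorization in Equation (3.108).

```python
import itertools
import math
import random

def sample_chain(pi0, A, T, seed=0):
    """Sample y_1 ~ pi0, then y_t ~ A[y_{t-1}] for t = 2..T."""
    rng = random.Random(seed)
    states = list(range(len(pi0)))
    y = [rng.choices(states, weights=pi0)[0]]
    for _ in range(T - 1):
        y.append(rng.choices(states, weights=A[y[-1]])[0])
    return y

def log_prob(y, pi0, A):
    """log p(y_{1:T}) = log p(y_1) + sum_t log p(y_t | y_{t-1})."""
    lp = math.log(pi0[y[0]])
    for prev, cur in zip(y, y[1:]):
        lp += math.log(A[prev][cur])
    return lp

# Hypothetical parameters: A[j][k] = p(y_t = k | y_{t-1} = j), each row sums to 1.
pi0 = [0.6, 0.4]
A = [[0.9, 0.1],
     [0.3, 0.7]]

seq = sample_chain(pi0, A, T=10)
# Summing the probability of every length-4 sequence should give 1.
total = sum(math.exp(log_prob(list(s), pi0, A))
            for s in itertools.product(range(2), repeat=4))
```

Note that the same matrix `A` is reused at every step, so the number of parameters does not depend on the sequence length; this is the parameter tying described above.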

The first-order Markov assumption is rather strong. Fortunately, we can easily generalize first-order
models to depend on the last $M$ observations, thus creating a model of order (memory length) $M$:

$$p(y_{1:T}) = p(y_{1:M}) \prod_{t=M+1}^T p(y_t|y_{t-M:t-1}) \tag{3.109}$$


This is called an $M$'th order Markov model. For example, if $M = 2$, $y_t$ depends on $y_{t-1}$ and
$y_{t-2}$, as shown in Figure 3.15(b). This is called a trigram model, since it models the distribution
over word triples. If we use $M = 1$, we get a bigram model, which models the distribution over
word pairs.

For large vocabulary sizes, the number of parameters needed to estimate the conditional distributions 
for M-gram models for large M can become prohibitive. In this case, we need to make additional 
assumptions beyond conditional independence. For example, we can assume that $p(y_t|y_{t-M:t-1})$ can
be represented as a low-rank matrix, or in terms of some kind of neural network. This is called a 
neural language model. See Chapter 15 for details. 


3.6.2 Inference 


A PGM defines a joint probability distribution. We can therefore use the rules of marginalization 
and conditioning to compute $p(Y_i|Y_j = y_j)$ for any sets of variables $i$ and $j$. Efficient algorithms to
perform this computation are discussed in the sequel to this book, [Mur23]. 

For example, consider the water sprinkler example in Figure 3.14. Our prior belief that it has 
rained is given by p(R = 1) = 0.5. If we see that the grass is wet, then our posterior belief that it has 
rained changes to p(R = 1|W = 1) = 0.7079. Now suppose we also notice the water sprinkler was 
turned on: our belief that it rained goes down to p(R = 1|W = 1, S = 1) = 0.3204. This negative 
mutual interaction between multiple causes of some observations is called the explaining away 
effect, also known as Berkson’s paradox. (See sprinkler_pgm.ipynb for some code that reproduces 
these calculations.) 
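For a network this small, the posteriors quoted above can be reproduced by brute-force enumeration of the joint distribution. The sketch below is not the code from sprinkler_pgm.ipynb; it is a minimal standalone version that hard-codes the CPT values from Figure 3.14 (with 1 for true and 0 for false), and the helper names are invented for illustration.

```python
from itertools import product

# CPT entries from Figure 3.14; each table stores the probability of the value 1 (true).
P_C = 0.5                                   # P(C=1)
P_S_GIVEN_C = {0: 0.5, 1: 0.1}              # P(S=1 | C)
P_R_GIVEN_C = {0: 0.2, 1: 0.8}              # P(R=1 | C)
P_W_GIVEN_SR = {(0, 0): 0.0, (1, 0): 0.9,   # P(W=1 | S, R)
                (0, 1): 0.9, (1, 1): 0.99}

def bern(p_one, value):
    """Probability of a binary value given the probability of 1."""
    return p_one if value == 1 else 1.0 - p_one

def joint(c, s, r, w):
    """p(C, S, R, W) = p(C) p(S|C) p(R|C) p(W|S,R), as in Equation (3.105)."""
    return (bern(P_C, c) * bern(P_S_GIVEN_C[c], s)
            * bern(P_R_GIVEN_C[c], r) * bern(P_W_GIVEN_SR[(s, r)], w))

def query(target, evidence):
    """p(target = 1 | evidence), summing the joint over all 16 assignments."""
    num = den = 0.0
    for c, s, r, w in product((0, 1), repeat=4):
        assignment = {"C": c, "S": s, "R": r, "W": w}
        if any(assignment[name] != value for name, value in evidence.items()):
            continue  # inconsistent with the evidence
        p = joint(c, s, r, w)
        den += p
        if assignment[target] == 1:
            num += p
    return num / den

prior_rain = query("R", {})
rain_given_wet = query("R", {"W": 1})
rain_given_wet_and_sprinkler = query("R", {"W": 1, "S": 1})
```

Enumeration costs time exponential in the number of variables, so it only serves as a reference implementation; the efficient algorithms mentioned above exploit the graph structure instead.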


3.6.3 Learning 


If the parameters of the CPDs are unknown, we can view them as additional random variables, add 
them as nodes to the graph, and then treat them as hidden variables to be inferred. Figure 3.16(a) 
shows a simple example, in which we have $N$ iid random variables, $y_n$, all drawn from the same
distribution with common parameter $\theta$. (The shaded nodes represent observed values, whereas the
unshaded (hollow) nodes represent latent variables or parameters.) 

More precisely, the model encodes the following “generative story” about the data: 


$$\theta \sim p(\theta) \tag{3.110}$$
$$y_n \sim p(y|\theta) \tag{3.111}$$

where $p(\theta)$ is some (unspecified) prior over the parameters, and $p(y|\theta)$ is some specified likelihood
function. The corresponding joint distribution has the form

$$p(\mathcal{D}, \theta) = p(\theta) p(\mathcal{D}|\theta) \tag{3.112}$$

where $\mathcal{D} = (y_1, \ldots, y_N)$. By virtue of the iid assumption, the likelihood can be rewritten as follows:

$$p(\mathcal{D}|\theta) = \prod_{n=1}^N p(y_n|\theta) \tag{3.113}$$


Notice that the order of the data vectors is not important for defining this model, i.e., we can permute 
the numbering of the leaf nodes in the PGM. When this property holds, we say that the data is 
exchangeable. 
Figure 3.16: Left: data points $y_n$ are conditionally independent given $\theta$. Right: Same model, using plate
notation. This represents the same model as the one on the left, except the repeated $y_n$ nodes are inside a
box, known as a plate; the number in the lower right hand corner, $N$, specifies the number of repetitions of
the $y_n$ node.


3.6.3.1 Plate notation 


In Figure 3.16(a), we see that the y nodes are repeated N times. To avoid visual clutter, it is common 
to use a form of syntactic sugar called plates. This is a notational convention in which we draw a 
little box around the repeated variables, with the understanding that nodes within the box will get 
repeated when the model is unrolled. We often write the number of copies or repetitions in the 
bottom right corner of the box. This is illustrated in Figure 3.16(b). This notation is widely used to 
represent certain kinds of Bayesian model. 

Figure 3.17 shows a more interesting example, in which we represent a GMM (Section 3.5.1) as a 
graphical model. We see that this encodes the joint distribution 


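$$p(y_{1:N}, z_{1:N}, \theta) = p(\pi) \left[\prod_{k=1}^K p(\mu_k) p(\Sigma_k)\right] \left[\prod_{n=1}^N p(z_n|\pi) p(y_n|z_n, \mu_{1:K}, \Sigma_{1:K})\right] \tag{3.114}$$

We see that the latent variables $z_n$, as well as the unknown parameters, $\theta = (\pi, \mu_{1:K}, \Sigma_{1:K})$, are all
shown as unshaded nodes.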


3.7 Exercises 


Exercise 3.1 [Uncorrelated does not imply independent *] 

Let $X \sim U(-1, 1)$ and $Y = X^2$. Clearly $Y$ is dependent on $X$ (in fact, $Y$ is uniquely determined by $X$).
However, show that $\rho(X, Y) = 0$. Hint: if $X \sim U(a, b)$ then $\mathbb{E}[X] = (a + b)/2$ and $\mathbb{V}[X] = (b - a)^2/12$.

Exercise 3.2 [Correlation coefficient is between -1 and +1]

Prove that $-1 \leq \rho(X, Y) \leq 1$.

Exercise 3.3 [Correlation coefficient for linearly related variables is ±1 *]

Show that, if $Y = aX + b$ for some parameters $a > 0$ and $b$, then $\rho(X, Y) = 1$. Similarly show that if $a < 0$,
then $\rho(X, Y) = -1$.
Figure 3.17: A Gaussian mixture model represented as a graphical model.


Exercise 3.4 [Linear combinations of random variables] 


Let $x$ be a random vector with mean $m$ and covariance matrix $\Sigma$. Let $A$ and $B$ be matrices.

a. Derive the covariance matrix of $Ax$.
b. Show that $\mathrm{tr}(AB) = \mathrm{tr}(BA)$.
c. Derive an expression for $\mathbb{E}[x^T A x]$.


Exercise 3.5 [Gaussian vs jointly Gaussian]

Let $X \sim \mathcal{N}(0, 1)$ and $Y = WX$, where $p(W = -1) = p(W = 1) = 0.5$. It is clear that $X$ and $Y$ are not
independent, since $Y$ is a function of $X$.

a. Show $Y \sim \mathcal{N}(0, 1)$.

b. Show $\mathrm{Cov}[X, Y] = 0$. Thus $X$ and $Y$ are uncorrelated but dependent, even though they are Gaussian.
Hint: use the definition of covariance

$$\mathrm{Cov}[X, Y] = \mathbb{E}[XY] - \mathbb{E}[X]\mathbb{E}[Y] \tag{3.115}$$

and the rule of iterated expectation

$$\mathbb{E}[XY] = \mathbb{E}[\mathbb{E}[XY|W]] \tag{3.116}$$


Exercise 3.6 [Normalization constant for a multidimensional Gaussian] 


Prove that the normalization constant for a d-dimensional Gaussian is given by 
$$(2\pi)^{d/2} |\Sigma|^{1/2} = \int \exp\left(-\frac{1}{2} (x - \mu)^T \Sigma^{-1} (x - \mu)\right) dx \tag{3.117}$$

Hint: diagonalize $\Sigma$ and use the fact that $|\Sigma| = \prod_i \lambda_i$ to write the joint pdf as a product of $d$ one-dimensional
Gaussians in a transformed coordinate system. (You will need the change of variables formula.) Finally, use
the normalization constant for univariate Gaussians.


Exercise 3.7 [Sensor fusion with known variances in 1d] 
Suppose we have two sensors with known (and different) variances $v_1$ and $v_2$, but unknown (and the same)
mean $\mu$. Suppose we observe $n_1$ observations $y_i^{(1)} \sim \mathcal{N}(\mu, v_1)$ from the first sensor and $n_2$ observations
$y_i^{(2)} \sim \mathcal{N}(\mu, v_2)$ from the second sensor. (For example, suppose $\mu$ is the true temperature outside, and sensor
1 is a precise (low variance) digital thermosensing device, and sensor 2 is an imprecise (high variance) mercury
thermometer.) Let $\mathcal{D}$ represent all the data from both sensors. What is the posterior $p(\mu|\mathcal{D})$, assuming a
non-informative prior for $\mu$ (which we can simulate using a Gaussian with a precision of 0)? Give an explicit
expression for the posterior mean and variance.

Exercise 3.8 [Show that the Student distribution can be written as a Gaussian scale mixture] 


Show that a Student distribution can be written as a Gaussian scale mixture, where we use a Gamma
mixing distribution on the precision $\alpha$, i.e.

$$p(x|\mu, a, b) = \int_0^\infty \mathcal{N}(x|\mu, \alpha^{-1}) \, \mathrm{Ga}(\alpha|a, b) \, d\alpha \tag{3.118}$$


This can be viewed as an infinite mixture of Gaussians, with different precisions. 
4.1 Introduction


In Chapter 2 and Chapter 3, we assumed all the parameters $\theta$ of our probability models were known. In
this chapter, we discuss how to learn these parameters from data.

The process of estimating $\theta$ from $\mathcal{D}$ is called model fitting, or training, and is at the heart of
machine learning. There are many methods for producing such estimates, but most boil down to an
optimization problem of the form

$$\hat{\theta} = \operatorname*{argmin}_{\theta} \mathcal{L}(\theta) \tag{4.1}$$

where $\mathcal{L}(\theta)$ is some kind of loss function or objective function. We discuss several different loss
functions in this chapter. In some cases, we also discuss how to solve the optimization problem in
closed form. In general, however, we will need to use some kind of generic optimization algorithm,
which we discuss in Chapter 8.

In addition to computing a point estimate, $\hat{\theta}$, we discuss how to model our uncertainty or
confidence in this estimate. In statistics, the process of quantifying uncertainty about an unknown
quantity estimated from a finite sample of data is called inference. We will discuss both Bayesian
and frequentist approaches to inference.¹


4.2 Maximum likelihood estimation (MLE) 

The most common approach to parameter estimation is to pick the parameters that assign the highest
probability to the training data; this is called maximum likelihood estimation or MLE. We give
more details below, and then give a series of worked examples.


4.2.1 Definition

We define the MLE as follows:

$$\hat{\theta}_{\mathrm{mle}} \triangleq \operatorname*{argmax}_{\theta} p(\mathcal{D}|\theta) \tag{4.2}$$


1. In the deep learning community, the term “inference” refers to what we will call “prediction”, namely computing
$p(y|x, \hat{\theta})$.


108 Chapter 4. Statistics 


We usually assume the training examples are independently sampled from the same distribution, so
the (conditional) likelihood becomes

$$p(\mathcal{D}|\theta) = \prod_{n=1}^N p(y_n|x_n, \theta) \tag{4.3}$$


This is known as the iid assumption, which stands for “independent and identically distributed”. We
usually work with the log likelihood, which is given by

$$\ell(\theta) \triangleq \log p(\mathcal{D}|\theta) = \sum_{n=1}^N \log p(y_n|x_n, \theta) \tag{4.4}$$

This decomposes into a sum of terms, one per example. Thus the MLE is given by

$$\hat{\theta}_{\mathrm{mle}} = \operatorname*{argmax}_{\theta} \sum_{n=1}^N \log p(y_n|x_n, \theta) \tag{4.5}$$


Since most optimization algorithms (such as those discussed in Chapter 8) are designed to minimize 
cost functions, we can redefine the objective function to be the (conditional) negative log 
likelihood or NLL: 


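$$\mathrm{NLL}(\theta) \triangleq -\log p(\mathcal{D}|\theta) = -\sum_{n=1}^N \log p(y_n|x_n, \theta) \tag{4.6}$$

Minimizing this will give the MLE. If the model is unconditional (unsupervised), the MLE becomes

$$\hat{\theta}_{\mathrm{mle}} = \operatorname*{argmin}_{\theta} -\sum_{n=1}^N \log p(y_n|\theta) \tag{4.7}$$

since we have outputs $y_n$ but no inputs $x_n$.²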


Alternatively we may want to maximize the joint likelihood of inputs and outputs. The MLE in 
this case becomes 


$$\hat{\theta}_{\mathrm{mle}} = \operatorname*{argmin}_{\theta} -\sum_{n=1}^N \log p(y_n, x_n|\theta) \tag{4.8}$$


4.2.2 Justification for MLE 


There are several ways to justify the method of MLE. One way is to view it as a simple point
approximation to the Bayesian posterior $p(\theta|\mathcal{D})$ using a uniform prior, as explained in Section 4.6.7.1.


2. In statistics, it is standard to use $y$ to represent variables whose generative distribution we choose to model, and to
use $x$ to represent exogenous inputs which are given but not generated. Thus supervised learning concerns fitting
conditional models of the form $p(y|x)$, and unsupervised learning is the special case where $x = \emptyset$, so we are just fitting
the unconditional distribution $p(y)$. In the ML literature, supervised learning treats $y$ as generated and $x$ as given,
but in the unsupervised case, it often switches to using $x$ to represent generated variables.
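In particular, suppose we approximate the posterior by a delta function, $p(\theta|\mathcal{D}) = \delta(\theta - \hat{\theta}_{\mathrm{map}})$, where
$\hat{\theta}_{\mathrm{map}}$ is the posterior mode, given by

$$\hat{\theta}_{\mathrm{map}} = \operatorname*{argmax}_{\theta} \log p(\theta|\mathcal{D}) = \operatorname*{argmax}_{\theta} \log p(\mathcal{D}|\theta) + \log p(\theta) \tag{4.9}$$

If we use a uniform prior, $p(\theta) \propto 1$, the MAP estimate becomes equal to the MLE, $\hat{\theta}_{\mathrm{map}} = \hat{\theta}_{\mathrm{mle}}$.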

Another way to justify the use of the MLE is that the resulting predictive distribution $p(y|\hat{\theta}_{\mathrm{mle}})$ is
as close as possible (in a sense to be defined below) to the empirical distribution of the data. In
the unconditional case, the empirical distribution is defined by

$$p_{\mathcal{D}}(y) \triangleq \frac{1}{N} \sum_{n=1}^N \delta(y - y_n) \tag{4.10}$$

We see that the empirical distribution is a series of delta functions or “spikes” at the observed training
points. We want to create a model whose distribution $q(y) = p(y|\theta)$ is similar to $p_{\mathcal{D}}(y)$.

A standard way to measure the (dis)similarity between probability distributions p and q is the 
Kullback Leibler divergence, or KL divergence. We give the details in Section 6.2, but in brief 
this is defined as 


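$$D_{\mathbb{KL}}(p \,\|\, q) = \sum_y p(y) \log \frac{p(y)}{q(y)} \tag{4.11}$$
$$= \underbrace{\sum_y p(y) \log p(y)}_{-\mathbb{H}(p)} \underbrace{- \sum_y p(y) \log q(y)}_{\mathbb{H}_{ce}(p, q)} \tag{4.12}$$

where $\mathbb{H}(p)$ is the entropy of $p$ (see Section 6.1), and $\mathbb{H}_{ce}(p, q)$ is the cross-entropy of $p$ and $q$ (see
Section 6.1.2). One can show that $D_{\mathbb{KL}}(p \,\|\, q) \geq 0$, with equality iff $p = q$.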
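If we define $q(y) = p(y|\theta)$, and set $p(y) = p_{\mathcal{D}}(y)$, then the KL divergence becomes

$$D_{\mathbb{KL}}(p \,\|\, q) = \sum_y \left[p_{\mathcal{D}}(y) \log p_{\mathcal{D}}(y) - p_{\mathcal{D}}(y) \log q(y)\right] \tag{4.13}$$
$$= -\mathbb{H}(p_{\mathcal{D}}) - \frac{1}{N} \sum_{n=1}^N \log p(y_n|\theta) \tag{4.14}$$
$$= \text{const} + \frac{1}{N} \mathrm{NLL}(\theta) \tag{4.15}$$

The first term is a constant which we can ignore, leaving just the NLL (scaled by $1/N$, which does not
change the minimizer). Thus minimizing the KL is
equivalent to minimizing the NLL, which is equivalent to computing the MLE, as in Equation (4.7).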
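The relationship between the KL divergence and the NLL is easy to verify for a small discrete dataset. In the sketch below, the data and the candidate model `q` are made up; the code computes the KL from the empirical distribution to `q` directly, and separately as the negative entropy plus the average NLL.

```python
import math
from collections import Counter

def kl_and_nll(data, q):
    """Return (KL(p_D || q), -H(p_D), NLL) for discrete data and model pmf q."""
    N = len(data)
    p_emp = {y: c / N for y, c in Counter(data).items()}  # empirical distribution
    kl = sum(p * math.log(p / q[y]) for y, p in p_emp.items())
    neg_entropy = sum(p * math.log(p) for p in p_emp.values())
    nll = -sum(math.log(q[y]) for y in data)  # summed over examples
    return kl, neg_entropy, nll

data = ["a", "b", "a", "c", "a", "b"]   # hypothetical observations
q = {"a": 0.5, "b": 0.3, "c": 0.2}      # hypothetical model p(y | theta)

kl, neg_entropy, nll = kl_and_nll(data, q)
# KL = -H(p_D) + NLL / N: a constant plus the per-example NLL.
reconstructed_kl = neg_entropy + nll / len(data)

# Setting q equal to the empirical distribution drives the KL to zero.
q_mle = {y: c / len(data) for y, c in Counter(data).items()}
kl_at_mle, _, _ = kl_and_nll(data, q_mle)
```

Since the entropy term does not depend on `q`, any `q` that lowers the NLL lowers the KL by the same amount (per example), and the empirical frequencies themselves achieve the minimum.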

We can generalize the above results to the supervised (conditional) setting by using the following 
empirical distribution: 


$$p_{\mathcal{D}}(x, y) = p_{\mathcal{D}}(y|x) p_{\mathcal{D}}(x) = \frac{1}{N} \sum_{n=1}^N \delta(x - x_n) \delta(y - y_n) \tag{4.16}$$




The expected KL then becomes 


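$$\mathbb{E}_{p_{\mathcal{D}}(x)}\left[D_{\mathbb{KL}}(p_{\mathcal{D}}(Y|x) \,\|\, q(Y|x))\right] = \sum_x p_{\mathcal{D}}(x) \left[\sum_y p_{\mathcal{D}}(y|x) \log \frac{p_{\mathcal{D}}(y|x)}{q(y|x)}\right] \tag{4.17}$$
$$= \text{const} - \sum_{x, y} p_{\mathcal{D}}(x, y) \log q(y|x) \tag{4.18}$$
$$= \text{const} - \frac{1}{N} \sum_{n=1}^N \log p(y_n|x_n, \theta) \tag{4.19}$$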


Minimizing this is equivalent to minimizing the conditional NLL in Equation (4.6). 


4.2.3 Example: MLE for the Bernoulli distribution 


Suppose Y is a random variable representing a coin toss, where the event Y = 1 corresponds to 
heads and Y = 0 corresponds to tails. Let $\theta = p(Y = 1)$ be the probability of heads. The probability
distribution for this rv is the Bernoulli, which we introduced in Section 2.4. 

The NLL for the Bernoulli distribution is given by 


$$\mathrm{NLL}(\theta) = -\log \prod_{n=1}^N p(y_n|\theta) \tag{4.20}$$
$$= -\log \prod_{n=1}^N \theta^{\mathbb{I}(y_n = 1)} (1 - \theta)^{\mathbb{I}(y_n = 0)} \tag{4.21}$$
$$= -\sum_{n=1}^N \left[\mathbb{I}(y_n = 1) \log \theta + \mathbb{I}(y_n = 0) \log(1 - \theta)\right] \tag{4.22}$$
$$= -\left[N_1 \log \theta + N_0 \log(1 - \theta)\right] \tag{4.23}$$

where we have defined $N_1 = \sum_{n=1}^N \mathbb{I}(y_n = 1)$ and $N_0 = \sum_{n=1}^N \mathbb{I}(y_n = 0)$, representing the number of
heads and tails. (The NLL for the binomial is the same as for the Bernoulli, modulo an irrelevant
$\binom{N}{c}$ term, which is a constant independent of $\theta$.) These two numbers are called the sufficient
statistics of the data, since they summarize everything we need to know about $\mathcal{D}$. The total count,
$N = N_0 + N_1$, is called the sample size.


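The MLE can be found by solving $\frac{d}{d\theta} \mathrm{NLL}(\theta) = 0$. The derivative of the NLL is

$$\frac{d}{d\theta} \mathrm{NLL}(\theta) = \frac{-N_1}{\theta} + \frac{N_0}{1 - \theta} \tag{4.24}$$

and hence the MLE is given by

$$\hat{\theta}_{\mathrm{mle}} = \frac{N_1}{N_0 + N_1} \tag{4.25}$$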

We see that this is just the empirical fraction of heads, which is an intuitive result. 
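The closed-form result can be compared against a direct numerical minimization of the NLL. The sketch below uses a made-up sequence of 8 coin tosses and a simple grid search over $\theta$; the grid search is only a stand-in for a generic optimizer and is not how one would fit the model in practice.

```python
import math

def bernoulli_nll(theta, data):
    """NLL(theta) = -[N1 log(theta) + N0 log(1 - theta)], as in Equation (4.23)."""
    n1 = sum(data)
    n0 = len(data) - n1
    return -(n1 * math.log(theta) + n0 * math.log(1.0 - theta))

def bernoulli_mle(data):
    """Closed-form MLE: the empirical fraction of heads."""
    return sum(data) / len(data)

data = [1, 0, 1, 1, 0, 1, 0, 1]  # hypothetical tosses: 5 heads, 3 tails
theta_hat = bernoulli_mle(data)

# Brute-force check: minimize the NLL over a grid on the open interval (0, 1).
grid = [i / 1000 for i in range(1, 1000)]
theta_grid = min(grid, key=lambda t: bernoulli_nll(t, data))
```

The grid minimizer coincides with the closed-form estimate, since the NLL is convex in $\theta$ with a single stationary point.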
4.2.4 Example: MLE for the categorical distribution


Suppose we roll a $K$-sided die $N$ times. Let $Y_n \in \{1, \ldots, K\}$ be the $n$'th outcome, where $Y_n \sim \mathrm{Cat}(\theta)$.
We want to estimate the probabilities $\theta$ from the dataset $\mathcal{D} = \{y_n : n = 1:N\}$. The NLL is given by

$$\mathrm{NLL}(\theta) = -\sum_k N_k \log \theta_k \tag{4.26}$$

where $N_k$ is the number of times the event $Y = k$ is observed. (The NLL for the multinomial is the
same, up to irrelevant scale factors.)

To compute the MLE, we have to minimize the NLL subject to the constraint that $\sum_{k=1}^K \theta_k = 1$.
To do this, we will use the method of Lagrange multipliers (see Section 8.5.1).³
The Lagrangian is as follows:

$$\mathcal{L}(\theta, \lambda) \triangleq -\sum_k N_k \log \theta_k - \lambda \left(1 - \sum_k \theta_k\right) \tag{4.27}$$

Taking derivatives with respect to $\lambda$ yields the original constraint:

$$\frac{\partial \mathcal{L}}{\partial \lambda} = 1 - \sum_k \theta_k = 0 \tag{4.28}$$

Taking derivatives with respect to $\theta_k$ yields

$$\frac{\partial \mathcal{L}}{\partial \theta_k} = -\frac{N_k}{\theta_k} + \lambda = 0 \implies N_k = \lambda \theta_k \tag{4.29}$$

We can solve for $\lambda$ using the sum-to-one constraint:

$$\sum_k N_k = N = \lambda \sum_k \theta_k = \lambda \tag{4.30}$$

Thus the MLE is given by

$$\hat{\theta}_k = \frac{N_k}{\lambda} = \frac{N_k}{N} \tag{4.31}$$

which is just the empirical fraction of times event $k$ occurs.
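In code, the categorical MLE amounts to counting and normalizing. The sketch below uses a made-up sequence of ten die rolls in which face 4 never appears, so its estimated probability is exactly zero; the function names are illustrative.

```python
import math
from collections import Counter

def categorical_mle(data, K):
    """MLE theta_k = N_k / N for outcomes labeled 1..K."""
    counts = Counter(data)
    N = len(data)
    return [counts.get(k, 0) / N for k in range(1, K + 1)]

def categorical_nll(theta, data):
    """NLL summed over the observed outcomes only."""
    return -sum(math.log(theta[y - 1]) for y in data)

rolls = [1, 3, 3, 6, 2, 3, 6, 1, 5, 3]  # hypothetical rolls of a 6-sided die
theta_hat = categorical_mle(rolls, K=6)

# The MLE should fit the observed data at least as well as a uniform die.
uniform = [1.0 / 6.0] * 6
nll_mle = categorical_nll(theta_hat, rolls)
nll_uniform = categorical_nll(uniform, rolls)
```

The zero estimate for the unseen face shows how the MLE can overfit small samples: it assigns no probability to an outcome simply because it has not been observed yet.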


4.2.5 Example: MLE for the univariate Gaussian 


Suppose $Y \sim \mathcal{N}(\mu, \sigma^2)$ and let $\mathcal{D} = \{y_n : n = 1:N\}$ be an iid sample of size $N$. We can estimate
the parameters $\theta = (\mu, \sigma^2)$ using MLE as follows. First, we derive the NLL, which is given by

$$\mathrm{NLL}(\mu, \sigma^2) = -\sum_{n=1}^N \log\left[\left(\frac{1}{2\pi\sigma^2}\right)^{\frac{1}{2}} \exp\left(-\frac{1}{2\sigma^2} (y_n - \mu)^2\right)\right] \tag{4.32}$$
$$= \frac{1}{2\sigma^2} \sum_{n=1}^N (y_n - \mu)^2 + \frac{N}{2} \log(2\pi\sigma^2) \tag{4.33}$$


3. We do not need to explicitly enforce the constraint that $\theta_k \geq 0$ since the gradient of the Lagrangian has the form
$-N_k/\theta_k + \lambda$; so negative values of $\theta_k$ would increase the objective, rather than minimize it. (Of course, this does not
preclude setting $\theta_k = 0$, and indeed this is the optimal solution if $N_k = 0$.)
The minimum of this function must satisfy the following conditions, which we explain in Section 8.1.1.1:

$$\frac{\partial}{\partial \mu} \mathrm{NLL}(\mu, \sigma^2) = 0, \quad \frac{\partial}{\partial \sigma^2} \mathrm{NLL}(\mu, \sigma^2) = 0 \tag{4.34}$$


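So all we have to do is to find this stationary point. Some simple calculus (Exercise 4.1) shows that
the solution is given by the following:

$$\hat{\mu}_{mle} = \frac{1}{N} \sum_{n=1}^N y_n = \bar{y} \tag{4.35}$$
$$\hat{\sigma}^2_{mle} = \frac{1}{N} \sum_{n=1}^N (y_n - \hat{\mu}_{mle})^2 = \frac{1}{N} \sum_{n=1}^N \left[ y_n^2 + \hat{\mu}_{mle}^2 - 2 y_n \hat{\mu}_{mle} \right] = s^2 - \bar{y}^2 \tag{4.36}$$
$$s^2 \triangleq \frac{1}{N} \sum_{n=1}^N y_n^2 \tag{4.37}$$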


The quantities $\bar{y}$ and $s^2$ are called the sufficient statistics of the data, since they are sufficient to
compute the MLE, without loss of information relative to using the raw data itself.
Note that you might be used to seeing the estimate for the variance written as

$$\hat{\sigma}^2_{unb} = \frac{1}{N-1} \sum_{n=1}^N (y_n - \hat{\mu}_{mle})^2 \tag{4.38}$$

where we divide by N − 1. This is not the MLE, but is a different kind of estimate, which happens
to be unbiased (unlike the MLE); see Section 4.7.6.1 for details.$^4$
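The relationship between Equations (4.35) to (4.38) can be checked numerically. The sketch below computes the MLE from the two sufficient statistics and compares it with the standard library's `statistics.pvariance` (divides by N) and `statistics.variance` (divides by N − 1). The sample values and the helper name `gaussian_mle` are made up for illustration.

```python
import statistics

def gaussian_mle(ys):
    """MLE (mu, sigma^2) computed from the sufficient statistics ybar and s^2."""
    N = len(ys)
    ybar = sum(ys) / N                      # empirical mean, Eq. (4.35)
    s2 = sum(y * y for y in ys) / N         # average sum of squares, Eq. (4.37)
    return ybar, s2 - ybar ** 2             # Eq. (4.36)

ys = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]     # made-up sample
mu_hat, var_mle = gaussian_mle(ys)
var_unbiased = var_mle * len(ys) / (len(ys) - 1)  # rescale to divide by N - 1

print(mu_hat, var_mle, statistics.pvariance(ys))  # MLE matches pvariance
print(var_unbiased, statistics.variance(ys))      # unbiased matches variance
```

The form $s^2 - \bar{y}^2$ is convenient for exposition, but it can lose precision when the mean is large relative to the spread; library routines typically center the data first.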


4.2.6 Example: MLE for the multivariate Gaussian 


In this section, we derive the maximum likelihood estimate for the parameters of a multivariate 
Gaussian. 
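First, let us write the log-likelihood, dropping irrelevant constants:

$$\ell(\boldsymbol{\mu}, \boldsymbol{\Sigma}) = \log p(\mathcal{D}|\boldsymbol{\mu}, \boldsymbol{\Sigma}) = \frac{N}{2} \log |\boldsymbol{\Lambda}| - \frac{1}{2} \sum_{n=1}^N (\mathbf{y}_n - \boldsymbol{\mu})^\top \boldsymbol{\Lambda} (\mathbf{y}_n - \boldsymbol{\mu}) \tag{4.39}$$

where $\boldsymbol{\Lambda} = \boldsymbol{\Sigma}^{-1}$ is the precision matrix (inverse covariance matrix).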


4. Note that, in Python, numpy defaults to the MLE, but Pandas defaults to the unbiased estimate, as explained in
https://stackoverflow.com/questions/24984178/different-std-in-pandas-vs-numpy/.
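4.2.6.1 MLE for the mean


Using the substitution $\mathbf{z}_n = \mathbf{y}_n - \boldsymbol{\mu}$, the derivative of a quadratic form (Equation (7.264)) and the
chain rule of calculus, we have

$$\frac{\partial}{\partial \boldsymbol{\mu}} (\mathbf{y}_n - \boldsymbol{\mu})^\top \boldsymbol{\Sigma}^{-1} (\mathbf{y}_n - \boldsymbol{\mu}) = \frac{\partial}{\partial \mathbf{z}_n} \mathbf{z}_n^\top \boldsymbol{\Sigma}^{-1} \mathbf{z}_n \frac{\partial \mathbf{z}_n}{\partial \boldsymbol{\mu}^\top} \tag{4.40}$$
$$= -1 (\boldsymbol{\Sigma}^{-1} + \boldsymbol{\Sigma}^{-\top}) \mathbf{z}_n \tag{4.41}$$

since $\frac{\partial \mathbf{z}_n}{\partial \boldsymbol{\mu}^\top} = -\mathbf{I}$. Hence

$$\frac{\partial}{\partial \boldsymbol{\mu}} \ell(\boldsymbol{\mu}, \boldsymbol{\Sigma}) = -\frac{1}{2} \sum_{n=1}^N -2 \boldsymbol{\Sigma}^{-1} (\mathbf{y}_n - \boldsymbol{\mu}) = \boldsymbol{\Sigma}^{-1} \sum_{n=1}^N (\mathbf{y}_n - \boldsymbol{\mu}) = \mathbf{0} \tag{4.42}$$
$$\hat{\boldsymbol{\mu}} = \frac{1}{N} \sum_{n=1}^N \mathbf{y}_n = \bar{\mathbf{y}} \tag{4.43}$$

So the MLE of $\boldsymbol{\mu}$ is just the empirical mean.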


4.2.6.2 MLE for the covariance matrix 


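We can use the trace trick (Equation (7.36)) to rewrite the log-likelihood in terms of the precision
matrix $\boldsymbol{\Lambda} = \boldsymbol{\Sigma}^{-1}$ as follows:

$$\ell(\hat{\boldsymbol{\mu}}, \boldsymbol{\Lambda}) = \frac{N}{2} \log |\boldsymbol{\Lambda}| - \frac{1}{2} \sum_n \mathrm{tr}\left[ (\mathbf{y}_n - \hat{\boldsymbol{\mu}})(\mathbf{y}_n - \hat{\boldsymbol{\mu}})^\top \boldsymbol{\Lambda} \right] \tag{4.44}$$
$$= \frac{N}{2} \log |\boldsymbol{\Lambda}| - \frac{1}{2} \mathrm{tr}\left[ \mathbf{S}_{\bar{\mathbf{y}}} \boldsymbol{\Lambda} \right] \tag{4.45}$$
$$\mathbf{S}_{\bar{\mathbf{y}}} \triangleq \sum_{n=1}^N (\mathbf{y}_n - \bar{\mathbf{y}})(\mathbf{y}_n - \bar{\mathbf{y}})^\top = \left( \sum_n \mathbf{y}_n \mathbf{y}_n^\top \right) - N \bar{\mathbf{y}} \bar{\mathbf{y}}^\top \tag{4.46}$$

where $\mathbf{S}_{\bar{\mathbf{y}}}$ is the scatter matrix centered on $\bar{\mathbf{y}}$.
We can rewrite the scatter matrix in a more compact form as follows:

$$\mathbf{S}_{\bar{\mathbf{y}}} = \tilde{\mathbf{Y}}^\top \tilde{\mathbf{Y}} = \mathbf{Y}^\top \mathbf{C}_N^\top \mathbf{C}_N \mathbf{Y} = \mathbf{Y}^\top \mathbf{C}_N \mathbf{Y} \tag{4.47}$$

where

$$\mathbf{C}_N \triangleq \mathbf{I}_N - \frac{1}{N} \mathbf{1}_N \mathbf{1}_N^\top \tag{4.48}$$

is the centering matrix, which converts $\mathbf{Y}$ to $\tilde{\mathbf{Y}}$ by subtracting the mean $\bar{\mathbf{y}} = \frac{1}{N} \mathbf{Y}^\top \mathbf{1}_N$ off every
row.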
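The identities in Equations (4.46) to (4.48) are easy to verify numerically. The following sketch, which assumes NumPy and uses randomly generated data, builds the centering matrix explicitly and checks that both expressions for the scatter matrix agree.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 6, 3
Y = rng.normal(size=(N, D))                 # made-up data matrix, one row per example

ones = np.ones((N, 1))
C = np.eye(N) - ones @ ones.T / N           # centering matrix C_N, Eq. (4.48)

Yc = Y - Y.mean(axis=0)                     # subtract the mean from every row
S_direct = Yc.T @ Yc                        # scatter matrix, Eq. (4.46)
S_center = Y.T @ C @ Y                      # compact form, Eq. (4.47)

print(np.allclose(S_direct, S_center))
```

The last equality in Equation (4.47) holds because the centering matrix is symmetric and idempotent. In practice one would subtract the column means directly rather than form the N × N matrix, which is only useful for analysis.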
Figure 4.1: (a) Covariance matrix for the features in the iris dataset from Section 1.2.1.1. (b) Correlation
matrix. We only show the lower triangle, since the matrix is symmetric and has a unit diagonal. Compare
this to Figure 1.3. Generated by iris_cov_mat.ipynb. [Axis labels in both panels: sepal_length, sepal_width, petal_length, petal_width.]


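Using results from Section 7.8, we can compute derivatives of the loss with respect to $\boldsymbol{\Lambda}$ to get

$$\frac{\partial \ell(\hat{\boldsymbol{\mu}}, \boldsymbol{\Lambda})}{\partial \boldsymbol{\Lambda}} = \frac{N}{2} \boldsymbol{\Lambda}^{-\top} - \frac{1}{2} \mathbf{S}_{\bar{\mathbf{y}}}^\top = \mathbf{0} \tag{4.49}$$
$$\hat{\boldsymbol{\Lambda}}^{-\top} = \hat{\boldsymbol{\Lambda}}^{-1} = \hat{\boldsymbol{\Sigma}} = \frac{1}{N} \mathbf{S}_{\bar{\mathbf{y}}} \tag{4.50}$$
$$\hat{\boldsymbol{\Sigma}} = \frac{1}{N} \sum_{n=1}^N (\mathbf{y}_n - \bar{\mathbf{y}})(\mathbf{y}_n - \bar{\mathbf{y}})^\top = \frac{1}{N} \mathbf{Y}^\top \mathbf{C}_N \mathbf{Y} \tag{4.51}$$

Thus the MLE for the covariance matrix is the empirical covariance matrix. See Figure 4.1a for an
example.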

Sometimes it is more convenient to work with the correlation matrix defined in Equation (3.8). 
This can be computed using 


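$$\mathrm{corr}(\mathbf{Y}) = (\mathrm{diag}(\boldsymbol{\Sigma}))^{-\frac{1}{2}} \, \boldsymbol{\Sigma} \, (\mathrm{diag}(\boldsymbol{\Sigma}))^{-\frac{1}{2}} \tag{4.52}$$

where $\mathrm{diag}(\boldsymbol{\Sigma})^{-\frac{1}{2}}$ is a diagonal matrix containing the entries $1/\sigma_i$. See Figure 4.1b for an example.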
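The covariance MLE and the correlation matrix can be computed in a few lines. This is a minimal sketch on randomly generated data (not the iris data used in Figure 4.1), assuming NumPy; it compares the result against `np.cov` with `bias=True` (which divides by N) and `np.corrcoef`.

```python
import numpy as np

rng = np.random.default_rng(1)
Y = rng.normal(size=(200, 4)) @ rng.normal(size=(4, 4))   # made-up correlated data

N = Y.shape[0]
Yc = Y - Y.mean(axis=0)
Sigma_mle = Yc.T @ Yc / N                   # Eq. (4.51): divide by N, not N - 1

d = 1.0 / np.sqrt(np.diag(Sigma_mle))       # entries 1 / sigma_i
corr = np.diag(d) @ Sigma_mle @ np.diag(d)  # Eq. (4.52)

print(np.allclose(Sigma_mle, np.cov(Y, rowvar=False, bias=True)))
print(np.allclose(corr, np.corrcoef(Y, rowvar=False)))
```

Note that `np.cov` divides by N − 1 unless `bias=True` is passed, whereas the correlation matrix is the same under either normalization because the scale factor cancels.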

Note, however, that the MLE may overfit or be numerically unstable, especially when the number
of samples N is small compared to the number of dimensions D. The main problem is that $\boldsymbol{\Sigma}$ has
$O(D^2)$ parameters, so we may need a lot of data to reliably estimate it. In particular, as we see from
Equation (4.51), the MLE for a full covariance matrix is singular if N < D. And even when N > D,
the MLE can be ill-conditioned, meaning it is close to singular. We discuss solutions to this problem
in Section 4.5.2.


4.2.7 Example: MLE for linear regression 


We briefly mentioned linear regression in Section 2.6.3. Recall that it corresponds to the following 
model: 


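$$p(y|\mathbf{x}; \boldsymbol{\theta}) = \mathcal{N}(y|\mathbf{w}^\top \mathbf{x}, \sigma^2) \tag{4.53}$$

where $\boldsymbol{\theta} = (\mathbf{w}, \sigma^2)$. Let us assume for now that $\sigma^2$ is fixed, and focus on estimating the weights $\mathbf{w}$.
The negative log likelihood or NLL is given by

$$\mathrm{NLL}(\mathbf{w}) = -\sum_{n=1}^N \log \left[ \left( \frac{1}{2\pi\sigma^2} \right)^{\frac{1}{2}} \exp\left( -\frac{1}{2\sigma^2} (y_n - \mathbf{w}^\top \mathbf{x}_n)^2 \right) \right] \tag{4.54}$$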


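Dropping the irrelevant additive constants gives the following simplified objective, known as the
residual sum of squares or RSS:

$$\mathrm{RSS}(\mathbf{w}) \triangleq \sum_{n=1}^N (y_n - \mathbf{w}^\top \mathbf{x}_n)^2 = \sum_{n=1}^N r_n^2 \tag{4.55}$$

where $r_n$ is the n'th residual error. Scaling by the number of examples N gives the mean squared
error or MSE:

$$\mathrm{MSE}(\mathbf{w}) = \frac{1}{N} \mathrm{RSS}(\mathbf{w}) = \frac{1}{N} \sum_{n=1}^N (y_n - \mathbf{w}^\top \mathbf{x}_n)^2 \tag{4.56}$$

Finally, taking the square root gives the root mean squared error or RMSE:

$$\mathrm{RMSE}(\mathbf{w}) = \sqrt{\mathrm{MSE}(\mathbf{w})} = \sqrt{\frac{1}{N} \sum_{n=1}^N (y_n - \mathbf{w}^\top \mathbf{x}_n)^2} \tag{4.57}$$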


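We can compute the MLE by minimizing the NLL, RSS, MSE or RMSE. All will give the same
results, since these objective functions differ only by additive constants, positive scale factors, or a
monotonic transformation (the square root), none of which changes the minimizer.
Let us focus on the RSS objective. It can be written in matrix notation as follows:

$$\mathrm{RSS}(\mathbf{w}) = \sum_{n=1}^N (y_n - \mathbf{w}^\top \mathbf{x}_n)^2 = \|\mathbf{X}\mathbf{w} - \mathbf{y}\|_2^2 = (\mathbf{X}\mathbf{w} - \mathbf{y})^\top (\mathbf{X}\mathbf{w} - \mathbf{y}) \tag{4.58}$$

In Section 11.2.2.1, we prove that the optimum, which occurs where $\nabla_{\mathbf{w}} \mathrm{RSS}(\mathbf{w}) = \mathbf{0}$, satisfies the
following equation:

$$\hat{\mathbf{w}}_{mle} \triangleq \operatorname*{argmin}_{\mathbf{w}} \mathrm{RSS}(\mathbf{w}) = (\mathbf{X}^\top \mathbf{X})^{-1} \mathbf{X}^\top \mathbf{y} \tag{4.59}$$

This is called the ordinary least squares or OLS estimate, and is equivalent to the MLE.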
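The OLS estimate in Equation (4.59) can be sketched as follows, assuming NumPy and a synthetic dataset (the design matrix, true weights, and noise level are made up for this example). We solve the normal equations directly and compare with `np.linalg.lstsq`, which is usually the numerically safer choice.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 50, 3
X = rng.normal(size=(N, D))                 # made-up design matrix
w_true = np.array([1.0, -2.0, 0.5])         # made-up ground-truth weights
y = X @ w_true + 0.1 * rng.normal(size=N)   # noisy linear responses

# Normal equations, Eq. (4.59); solve() avoids forming the explicit inverse.
w_normal = np.linalg.solve(X.T @ X, X.T @ y)

# SVD-based least-squares solver; it also handles rank-deficient X.
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)

rss = np.sum((X @ w_normal - y) ** 2)       # Eq. (4.58)
grad = 2 * X.T @ (X @ w_normal - y)         # gradient of the RSS at the solution
print(w_normal, rss)
```

At the solution the gradient of the RSS vanishes (up to rounding), which is the stationarity condition stated above.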


4.3 Empirical risk minimization (ERM) 


We can generalize MLE by replacing the (conditional) log loss term in Equation (4.6), $\ell(y_n, \boldsymbol{\theta}; \mathbf{x}_n) =
-\log p(y_n|\mathbf{x}_n, \boldsymbol{\theta})$, with any other loss function, to get

$$\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{N} \sum_{n=1}^N \ell(y_n, \boldsymbol{\theta}; \mathbf{x}_n) \tag{4.60}$$

This is known as empirical risk minimization or ERM, since it is the expected loss where the
expectation is taken wrt the empirical distribution. See Section 5.4 for more details.
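4.3.1 Example: minimizing the misclassification rate


If we are solving a classification problem, we might want to use 0-1 loss:

$$\ell_{01}(y_n, \boldsymbol{\theta}; \mathbf{x}_n) = \begin{cases} 0 & \text{if } y_n = f(\mathbf{x}_n; \boldsymbol{\theta}) \\ 1 & \text{if } y_n \neq f(\mathbf{x}_n; \boldsymbol{\theta}) \end{cases} \tag{4.61}$$

where $f(\mathbf{x}; \boldsymbol{\theta})$ is some kind of predictor. The empirical risk becomes

$$\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{N} \sum_{n=1}^N \ell_{01}(y_n, \boldsymbol{\theta}; \mathbf{x}_n) \tag{4.62}$$

This is just the empirical misclassification rate on the training set.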

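Note that for binary problems, we can rewrite the misclassification rate in the following notation.
Let $\tilde{y} \in \{-1, +1\}$ be the true label, and $\hat{y} = f(\mathbf{x}; \boldsymbol{\theta}) \in \{-1, +1\}$ be our prediction. We define the 0-1
loss as follows:

$$\ell_{01}(\tilde{y}, \hat{y}) = \mathbb{I}(\tilde{y} \neq \hat{y}) = \mathbb{I}(\tilde{y} \hat{y} < 0) \tag{4.63}$$

The corresponding empirical risk becomes

$$\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{N} \sum_{n=1}^N \ell_{01}(\tilde{y}_n, \hat{y}_n) = \frac{1}{N} \sum_{n=1}^N \mathbb{I}(\tilde{y}_n \hat{y}_n < 0) \tag{4.64}$$

where the dependence on $\mathbf{x}_n$ and $\boldsymbol{\theta}$ is implicit.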
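The margin form of the empirical 0-1 risk in Equation (4.64) can be sketched in a few lines. The labels and predictions below are made up, and the function name `zero_one_risk` is hypothetical.

```python
def zero_one_risk(y_true, y_pred):
    """Empirical 0-1 risk for labels in {-1, +1}, using the margin form of Eq. (4.64)."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    errors = sum(1 for y, yhat in zip(y_true, y_pred) if y * yhat < 0)
    return errors / len(y_true)

y_true = [+1, -1, +1, +1, -1]               # made-up labels
y_pred = [+1, +1, +1, -1, -1]               # made-up predictions
risk = zero_one_risk(y_true, y_pred)
print(risk)                                 # two of the five predictions are wrong
```

Because both the label and the prediction are ±1, their product is negative exactly when they disagree, so this matches the direct count of mismatches.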


4.3.2 Surrogate loss 


Unfortunately, the 0-1 loss used in Section 4.3.1 is a non-smooth step function, as shown in Figure 4.2, 
making it difficult to optimize. (In fact, it is NP-hard [BDEL03].) In this section we consider the use 
of a surrogate loss function [BJM06]. The surrogate is usually chosen to be a maximally tight 
convex upper bound, which is then easy to minimize. 

For example, consider a probabilistic binary classifier, which produces the following distribution
over labels:

$$p(\tilde{y}|\mathbf{x}, \boldsymbol{\theta}) = \sigma(\tilde{y}\eta) = \frac{1}{1 + e^{-\tilde{y}\eta}} \tag{4.65}$$

where $\eta = f(\mathbf{x}; \boldsymbol{\theta})$ is the log odds. Hence the log loss is given by

$$\ell_{ll}(\tilde{y}, \eta) = -\log p(\tilde{y}|\eta) = \log(1 + e^{-\tilde{y}\eta}) \tag{4.66}$$

Figure 4.2 shows that this is a smooth upper bound to the 0-1 loss, where we plot the loss vs the
quantity $\tilde{y}\eta$, known as the margin, since it defines a “margin of safety” away from the threshold
value of 0. Thus we see that minimizing the negative log likelihood is equivalent to minimizing a
(fairly tight) upper bound on the empirical 0-1 loss.

Another convex upper bound to 0-1 loss is the hinge loss, which is defined as follows:

$$\ell_{hinge}(\tilde{y}, \eta) = \max(0, 1 - \tilde{y}\eta) \triangleq (1 - \tilde{y}\eta)_+ \tag{4.67}$$

This is plotted in Figure 4.2; we see that it has the shape of a partially open door hinge. This is a
convex upper bound to the 0-1 loss, although it is only piecewise differentiable, not everywhere
differentiable.
Figure 4.2: Illustration of various loss functions for binary classification. The horizontal axis is the margin
$z = \tilde{y}\eta$, the vertical axis is the loss. 0-1 loss is $\mathbb{I}(z < 0)$. Hinge-loss is $\max(0, 1 - z)$. Log-loss is $\log_2(1 + e^{-z})$.
Exp-loss is $e^{-z}$. Generated by hinge_loss_plot.ipynb.
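The upper-bound property of the surrogate losses can be checked on a grid of margins. The sketch below implements the four losses as functions of the margin z, using the base-2 logarithm for the log loss as in the caption of Figure 4.2; the function names and the grid are choices made for this example.

```python
import math

def zero_one(z):
    return 1.0 if z < 0 else 0.0

def hinge(z):
    return max(0.0, 1.0 - z)

def log_loss_base2(z):
    return math.log2(1.0 + math.exp(-z))    # base 2, so the loss equals 1 at z = 0

def exp_loss(z):
    return math.exp(-z)

margins = [m / 10 for m in range(-30, 31)]  # z from -3 to 3 in steps of 0.1
for name, loss in [("hinge", hinge), ("log2", log_loss_base2), ("exp", exp_loss)]:
    is_bound = all(loss(z) >= zero_one(z) for z in margins)
    print(name, "upper-bounds 0-1 loss on the grid:", is_bound)
```

One detail worth noting: with the natural logarithm, the log loss at small negative margins is below 1 (it approaches log 2 ≈ 0.693 as z → 0), so it is an upper bound on the 0-1 loss only after rescaling by 1/log 2, which is what the base-2 form does.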


4.4 Other estimation methods * 


4.4.1 The method of moments 


Computing the MLE requires solving the equation $\nabla_{\boldsymbol{\theta}} \mathrm{NLL}(\boldsymbol{\theta}) = \mathbf{0}$. Sometimes this is computationally
difficult. In such cases, we may be able to use a simpler approach known as the method of moments
(MOM). In this approach, we equate the theoretical moments of the distribution to the empirical
moments, and solve the resulting set of K simultaneous equations, where K is the number of
parameters. The theoretical moments are given by $\mu_k = \mathbb{E}[Y^k]$, for $k = 1:K$, and the empirical
moments are given by

$$\hat{\mu}_k = \frac{1}{N} \sum_{n=1}^N y_n^k \tag{4.68}$$

so we just need to solve $\mu_k = \hat{\mu}_k$ for each k. We give some examples below.

The method of moments is simple, but it is theoretically inferior to the MLE approach, since it 
may not use all the data as efficiently. (For details on these theoretical results, see e.g., [CB02].) 
Furthermore, it can sometimes produce inconsistent results (see Section 4.4.1.2). However, when it 
produces valid estimates, it can be used to initialize iterative algorithms that are used to optimize the 
NLL (see e.g., [AHK12]), thus combining the computational efficiency of MOM with the statistical 
accuracy of MLE. 


4.4.1.1 Example: MOM for the univariate Gaussian 

For example, consider the case of a univariate Gaussian distribution. From Section 4.2.5, we have

$$\mu_1 = \mu = \bar{y} \tag{4.69}$$
$$\mu_2 = \sigma^2 + \mu^2 = s^2 \tag{4.70}$$

where $\bar{y}$ is the empirical mean and $s^2$ is the empirical average sum of squares, so $\hat{\mu} = \bar{y}$ and
$\hat{\sigma}^2 = s^2 - \bar{y}^2$. In this case, the MOM estimate is the same as the MLE, but this is not always the
case.


4.4.1.2 Example: MOM for the uniform distribution 


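In this section, we give an example of the MOM applied to the uniform distribution. Our presentation
follows the Wikipedia page.$^5$ Let $Y \sim \mathrm{Unif}(\theta_1, \theta_2)$ be a uniform random variable, so

$$p(y|\boldsymbol{\theta}) = \frac{1}{\theta_2 - \theta_1} \mathbb{I}(\theta_1 \leq y \leq \theta_2) \tag{4.71}$$

The first two moments are

$$\mu_1 = \mathbb{E}[Y] = \frac{1}{2}(\theta_1 + \theta_2) \tag{4.72}$$
$$\mu_2 = \mathbb{E}[Y^2] = \frac{1}{3}(\theta_1^2 + \theta_1\theta_2 + \theta_2^2) \tag{4.73}$$

Inverting these equations gives

$$(\theta_1, \theta_2) = \left( \mu_1 - \sqrt{3(\mu_2 - \mu_1^2)}, \; 2\mu_1 - \theta_1 \right) \tag{4.74}$$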


Unfortunately this estimator can sometimes give invalid results. For example, suppose $\mathcal{D} =
\{0, 0, 0, 0, 1\}$. The empirical moments are $\hat{\mu}_1 = \frac{1}{5}$ and $\hat{\mu}_2 = \frac{1}{5}$, so the estimated parameters are
$\hat{\theta}_1 = \frac{1}{5} - \frac{2\sqrt{3}}{5} = -0.493$ and $\hat{\theta}_2 = \frac{1}{5} + \frac{2\sqrt{3}}{5} = 0.893$. However, these cannot possibly be the correct
parameters, since if $\theta_2 = 0.893$, we cannot generate a sample as large as 1.
By contrast, consider the MLE. Let $y_{(1)} \leq y_{(2)} \leq \cdots \leq y_{(N)}$ be the order statistics of the data
(i.e., the values sorted in increasing order). Let $\theta = \theta_2 - \theta_1$. Then the likelihood is given by

$$p(\mathcal{D}|\boldsymbol{\theta}) = (\theta)^{-N} \, \mathbb{I}(y_{(1)} \geq \theta_1) \, \mathbb{I}(y_{(N)} \leq \theta_2) \tag{4.75}$$

Within the permitted bounds for $\theta$, the derivative of the log likelihood is given by

$$\frac{d}{d\theta} \log p(\mathcal{D}|\boldsymbol{\theta}) = -\frac{N}{\theta} < 0 \tag{4.76}$$

Hence the likelihood is a decreasing function of $\theta$, so we should pick

$$\hat{\theta}_1 = y_{(1)}, \quad \hat{\theta}_2 = y_{(N)} \tag{4.77}$$

In the above example, we get $\hat{\theta}_1 = 0$ and $\hat{\theta}_2 = 1$, as one would expect.
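The comparison between the two estimators on the dataset {0, 0, 0, 0, 1} can be reproduced with the following sketch. The function names `uniform_mom` and `uniform_mle` are made up for this example; the formulas are Equations (4.74) and (4.77).

```python
import math

def uniform_mom(ys):
    """Method-of-moments estimate of (theta1, theta2), Eq. (4.74)."""
    N = len(ys)
    m1 = sum(ys) / N                        # first empirical moment
    m2 = sum(y * y for y in ys) / N         # second empirical moment
    theta1 = m1 - math.sqrt(3.0 * (m2 - m1 ** 2))
    theta2 = 2.0 * m1 - theta1
    return theta1, theta2

def uniform_mle(ys):
    """MLE of (theta1, theta2): the smallest and largest observations, Eq. (4.77)."""
    return min(ys), max(ys)

data = [0, 0, 0, 0, 1]
mom = uniform_mom(data)
mle = uniform_mle(data)
print("MOM:", mom)                          # upper end falls below the largest sample
print("MLE:", mle)
```

The MOM interval excludes the observation y = 1, which has zero density under the fitted model, whereas the MLE interval contains every data point by construction.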


5. https://en.wikipedia.org/wiki/Method_of_moments_(statistics).
4.4.2 Online (recursive) estimation


If the entire dataset $\mathcal{D}$ is available before training starts, we say that we are doing batch learning.
However, in some cases, the data set arrives sequentially, so $\mathcal{D} = \{y_1, y_2, \ldots\}$ in an unbounded
stream. In this case, we want to perform online learning.

Let $\hat{\boldsymbol{\theta}}_{t-1}$ be our estimate (e.g., MLE) given $\mathcal{D}_{1:t-1}$. To ensure our learning algorithm takes constant
time per update, we need to find a learning rule of the form

$$\boldsymbol{\theta}_t = f(\hat{\boldsymbol{\theta}}_{t-1}, y_t) \tag{4.78}$$


This is called a recursive update. Below we give some examples of such online learning methods. 


4.4.2.1 Example: recursive MLE for the mean of a Gaussian 


Let us reconsider the example from Section 4.2.5 where we computed the MLE for a univariate
Gaussian. We know that the batch estimate for the mean is given by

$$\hat{\mu}_t = \frac{1}{t} \sum_{n=1}^t y_n \tag{4.79}$$

This is just a running sum of the data divided by $t$, so we can easily convert this into a recursive estimate as
follows:

$$\hat{\mu}_t = \frac{1}{t} \sum_{n=1}^t y_n = \frac{1}{t} \left( (t-1)\hat{\mu}_{t-1} + y_t \right) \tag{4.80}$$
$$= \hat{\mu}_{t-1} + \frac{1}{t} (y_t - \hat{\mu}_{t-1}) \tag{4.81}$$

This is known as a moving average.

We see from Equation (4.81) that the new estimate is the old estimate plus a correction term. The
size of the correction diminishes over time (i.e., as we get more samples). However, if the distribution
is changing, we want to give more weight to more recent data examples. We discuss how to do this
in Section 4.4.2.2.
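The recursive update in Equation (4.81) can be written as a generator that consumes one observation at a time. This is a minimal sketch with a made-up stream; the name `running_means` is hypothetical.

```python
def running_means(stream):
    """Yield the MLE of the mean after each observation, in O(1) time per update."""
    mu = 0.0
    for t, y in enumerate(stream, start=1):
        mu += (y - mu) / t                  # Eq. (4.81): old estimate plus correction
        yield mu

ys = [3.0, 5.0, 4.0, 8.0]                   # made-up stream
means = list(running_means(ys))
print(means)                                # last entry equals the batch mean of ys
```

Only the current estimate and the step counter need to be stored, so memory use is constant regardless of the length of the stream. The initial value of `mu` is irrelevant because the first update has step size 1.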


4.4.2.2 Exponentially-weighted moving average 


Equation (4.81) shows how to compute the moving average of a signal. In this section, we show
how to adjust this to give more weight to more recent examples. In particular, we will compute
the following exponentially weighted moving average or EWMA, also called an exponential
moving average or EMA:

$$\hat{\mu}_t = \beta \hat{\mu}_{t-1} + (1 - \beta) y_t \tag{4.82}$$

where $0 < \beta < 1$. The contribution of a data point k steps in the past is weighted by $\beta^k (1 - \beta)$.
Thus the contribution from old data is exponentially decreasing. In particular, we have

$$\hat{\mu}_t = \beta \hat{\mu}_{t-1} + (1 - \beta) y_t \tag{4.83}$$
$$= \beta^2 \hat{\mu}_{t-2} + \beta(1 - \beta) y_{t-1} + (1 - \beta) y_t \tag{4.84}$$
$$= \beta^t y_0 + (1 - \beta)\beta^{t-1} y_1 + \cdots + (1 - \beta)\beta y_{t-1} + (1 - \beta) y_t \tag{4.85}$$
Figure 4.3: Illustration of exponentially-weighted moving average with and without bias correction. (a) Short
memory: $\beta = 0.9$. (b) Long memory: $\beta = 0.99$. Generated by ema_demo.ipynb.


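The sum of a geometric series is given by

$$\beta^t + \beta^{t-1} + \cdots + \beta^1 + \beta^0 = \frac{1 - \beta^{t+1}}{1 - \beta} \tag{4.86}$$

Hence

$$(1 - \beta) \sum_{k=0}^t \beta^k = (1 - \beta) \frac{1 - \beta^{t+1}}{1 - \beta} = 1 - \beta^{t+1} \tag{4.87}$$

Since $0 < \beta < 1$, we have $\beta^{t+1} \to 0$ as $t \to \infty$, so smaller $\beta$ forgets the past more quickly, and adapts
to the more recent data more rapidly. This is illustrated in Figure 4.3.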

Since the initial estimate starts from $\hat{\mu}_0 = 0$, there is an initial bias. This can be corrected by
scaling as follows [KB15]:

$$\tilde{\mu}_t = \frac{\hat{\mu}_t}{1 - \beta^t} \tag{4.88}$$

(Note that the update in Equation (4.82) is still applied to the uncorrected EMA, $\hat{\mu}_{t-1}$, before being
corrected for the current time step.) The benefit of this is illustrated in Figure 4.3.
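The EMA update and its bias correction can be sketched as below. A constant input signal is used (an assumption made purely so that the bias is easy to see), and the function name `ema` is hypothetical.

```python
def ema(ys, beta, bias_correction=True):
    """Exponentially weighted moving average, Eq. (4.82), started from mu_0 = 0."""
    mu = 0.0                                # uncorrected running estimate
    out = []
    for t, y in enumerate(ys, start=1):
        mu = beta * mu + (1.0 - beta) * y   # update always uses the uncorrected value
        out.append(mu / (1.0 - beta ** t) if bias_correction else mu)  # Eq. (4.88)
    return out

ys = [5.0] * 20                             # constant signal, so the ideal output is 5
raw = ema(ys, beta=0.9, bias_correction=False)
corrected = ema(ys, beta=0.9)
print(raw[:3])                              # starts far below 5 because mu_0 = 0
print(corrected[:3])                        # bias-corrected values sit at 5
```

For a constant signal c, the uncorrected estimate after t steps is c(1 − β^t), so dividing by 1 − β^t recovers c exactly; for a varying signal the correction matters mainly during the first few steps, while β^t is still far from zero.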


4.5 Regularization 


A fundamental problem with MLE, and ERM, is that it will try to pick parameters that minimize 
loss on the training set, but this may not result in a model that has low loss on future data. This is 
called overfitting. 

As a simple example, suppose we want to predict the probability of heads when tossing a coin. We
toss it N = 3 times and observe 3 heads. The MLE is $\hat{\theta}_{mle} = N_1/(N_0 + N_1) = 3/(3 + 0) = 1$ (see
Section 4.2.3). However, if we use $\mathrm{Ber}(y|\hat{\theta}_{mle})$ to make predictions, we will predict that all future
coin tosses will also be heads, which seems rather unlikely.
The core of the problem is that the model has enough parameters to perfectly fit the observed
training data, so it can perfectly match the empirical distribution. However, in most cases the 
empirical distribution is not the same as the true distribution, so putting all the probability mass on 
the observed set of N examples will not leave over any probability for novel data in the future. That 
is, the model may not generalize. 

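The main solution to overfitting is to use regularization, which means to add a penalty term to
the NLL (or empirical risk). Thus we optimize an objective of the form

$$\mathcal{L}(\boldsymbol{\theta}; \lambda) = \left[ \frac{1}{N} \sum_{n=1}^N \ell(y_n, \boldsymbol{\theta}; \mathbf{x}_n) \right] + \lambda C(\boldsymbol{\theta}) \tag{4.89}$$

where $\lambda \geq 0$ is the regularization parameter, and $C(\boldsymbol{\theta})$ is some form of complexity penalty.
A common complexity penalty is to use $C(\boldsymbol{\theta}) = -\log p(\boldsymbol{\theta})$, where $p(\boldsymbol{\theta})$ is the prior for $\boldsymbol{\theta}$. If $\ell$ is
the log loss, the regularized objective becomes

$$\mathcal{L}(\boldsymbol{\theta}; \lambda) = -\frac{1}{N} \sum_{n=1}^N \log p(y_n|\mathbf{x}_n, \boldsymbol{\theta}) - \lambda \log p(\boldsymbol{\theta}) \tag{4.90}$$

By setting $\lambda = 1$ and rescaling $p(\boldsymbol{\theta})$ appropriately, we can equivalently minimize the following:

$$\mathcal{L}(\boldsymbol{\theta}; \lambda) = -\left[ \sum_{n=1}^N \log p(y_n|\mathbf{x}_n, \boldsymbol{\theta}) + \log p(\boldsymbol{\theta}) \right] = -\left[ \log p(\mathcal{D}|\boldsymbol{\theta}) + \log p(\boldsymbol{\theta}) \right] \tag{4.91}$$

Minimizing this is equivalent to maximizing the log posterior:

$$\hat{\boldsymbol{\theta}} = \operatorname*{argmax}_{\boldsymbol{\theta}} \log p(\boldsymbol{\theta}|\mathcal{D}) = \operatorname*{argmax}_{\boldsymbol{\theta}} \left[ \log p(\mathcal{D}|\boldsymbol{\theta}) + \log p(\boldsymbol{\theta}) - \text{const} \right] \tag{4.92}$$

This is known as MAP estimation, which stands for maximum a posteriori estimation.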


4.5.1 Example: MAP estimation for the Bernoulli distribution 


Consider again the coin tossing example. If we observe just one head, the MLE is $\theta_{mle} = 1$, which
predicts that all future coin tosses will also show up heads. To avoid such overfitting, we can add a
penalty to $\theta$ to discourage “extreme” values, such as $\theta = 0$ or $\theta = 1$. We can do this by using a beta
distribution as our prior, $p(\theta) = \mathrm{Beta}(\theta|a, b)$, where $a, b > 1$ encourages values of $\theta$ near to $a/(a + b)$
(see Section 2.7.4 for details). The log likelihood plus log prior becomes

$$\ell(\theta) = \log p(\mathcal{D}|\theta) + \log p(\theta) \tag{4.93}$$
$$= [N_1 \log \theta + N_0 \log(1 - \theta)] + [(a - 1) \log(\theta) + (b - 1) \log(1 - \theta)] \tag{4.94}$$

Using the method from Section 4.2.3 we find that the MAP estimate is

$$\theta_{map} = \frac{N_1 + a - 1}{N_1 + N_0 + a + b - 2} \tag{4.95}$$
If we set $a = b = 2$ (which weakly favors a value of $\theta$ near 0.5), the estimate becomes

$$\theta_{map} = \frac{N_1 + 1}{N_1 + N_0 + 2} \tag{4.96}$$


This is called add-one smoothing, and is a simple but widely used technique to avoid the zero 
count problem. (See also Section 4.6.2.9.) 
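The effect of the prior on the coin example (three heads, no tails) can be sketched as follows. The function names are made up for this example; the formula is Equation (4.95), and the default a = b = 2 gives add-one smoothing.

```python
def bernoulli_mle(n1, n0):
    """MLE for the probability of heads given n1 heads and n0 tails."""
    return n1 / (n1 + n0)

def bernoulli_map(n1, n0, a=2.0, b=2.0):
    """MAP estimate under a Beta(a, b) prior, Eq. (4.95).

    Assumes a, b >= 1 and n1 + n0 + a + b > 2 so the expression is well defined.
    """
    return (n1 + a - 1.0) / (n1 + n0 + a + b - 2.0)

n1, n0 = 3, 0                               # three heads, no tails
print(bernoulli_mle(n1, n0))                # extreme estimate from the MLE
print(bernoulli_map(n1, n0))                # pulled towards 0.5 by the prior
print(bernoulli_map(n1, n0, a=1.0, b=1.0))  # uniform prior recovers the MLE
```

With a = b = 2 the prior acts like one extra head and one extra tail, so the estimate moves from 1 to 4/5; as the number of tosses grows, the influence of these pseudo-counts vanishes.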

The zero-count problem, and overfitting more generally, is analogous to a problem in philosophy 
called the black swan paradox. This is based on the ancient Western conception that all swans 
were white. In that context, a black swan was a metaphor for something that could not exist. (Black 
swans were discovered in Australia by European explorers in the 17th Century.) The term “black 
swan paradox” was first coined by the famous philosopher of science Karl Popper; the term has also 
been used as the title of a recent popular book [Tal07]. This paradox was used to illustrate the 
problem of induction, which is the problem of how to draw general conclusions about the future 
from specific observations from the past. The solution to the paradox is to admit that induction is in 
general impossible, and that the best we can do is to make plausible guesses about what the future 
might hold, by combining the empirical data with prior knowledge. 


4.5.2 Example: MAP estimation for the multivariate Gaussian * 


In Section 4.2.6, we showed that the MLE for the mean of an MVN is the empirical mean, $\hat{\boldsymbol{\mu}}_{mle} = \bar{\mathbf{y}}$.
We also showed that the MLE for the covariance is the empirical covariance, $\hat{\boldsymbol{\Sigma}}_{mle} = \frac{1}{N} \mathbf{S}_{\bar{\mathbf{y}}}$.

In high dimensions the estimate for $\boldsymbol{\Sigma}$ can easily become singular. One solution to this is to
perform MAP estimation, as we explain below.


4.5.2.1 Shrinkage estimate 


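A convenient prior to use for $\boldsymbol{\Sigma}$ is the inverse Wishart prior. This is a distribution over positive
definite matrices, where the parameters are defined in terms of a prior scatter matrix, $\breve{\mathbf{S}}$, and a prior
sample size or strength $\breve{N}$. One can show that the resulting MAP estimate is given by

$$\hat{\boldsymbol{\Sigma}}_{map} = \frac{\breve{\mathbf{S}} + \mathbf{S}_{\bar{\mathbf{y}}}}{\breve{N} + N} = \frac{\breve{N}}{\breve{N} + N} \frac{\breve{\mathbf{S}}}{\breve{N}} + \frac{N}{\breve{N} + N} \frac{\mathbf{S}_{\bar{\mathbf{y}}}}{N} = \lambda \boldsymbol{\Sigma}_0 + (1 - \lambda) \hat{\boldsymbol{\Sigma}}_{mle} \tag{4.97}$$

where $\lambda = \frac{\breve{N}}{\breve{N} + N}$ controls the amount of regularization.

A common choice (see e.g., [FR07, p6]) for the prior scatter matrix is to use $\breve{\mathbf{S}} = \breve{N} \, \mathrm{diag}(\hat{\boldsymbol{\Sigma}}_{mle})$.
With this choice, we find that the MAP estimate for $\boldsymbol{\Sigma}$ is given by

$$\hat{\boldsymbol{\Sigma}}_{map}(i, j) = \begin{cases} \hat{\boldsymbol{\Sigma}}_{mle}(i, j) & \text{if } i = j \\ (1 - \lambda) \hat{\boldsymbol{\Sigma}}_{mle}(i, j) & \text{otherwise} \end{cases} \tag{4.98}$$

Thus we see that the diagonal entries are equal to their ML estimates, and the off-diagonal elements
are “shrunk” somewhat towards 0. This technique is therefore called shrinkage estimation.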

The other parameter we need to set is $\lambda$, which controls the amount of regularization (shrinkage
of the MLE towards the prior). It is common to set $\lambda$ by cross validation (Section 4.5.5). Alternatively,
we can use the closed-form formula provided in [LW04a; LW04b; SS05], which is the optimal
frequentist estimate if we use squared loss. This is implemented in the sklearn function
https://scikit-learn.org/stable/modules/generated/sklearn.covariance.LedoitWolf.html.

Figure 4.4: Estimating a covariance matrix in D = 50 dimensions using $N \in \{100, 50, 25\}$ samples. We
plot the eigenvalues in descending order for the true covariance matrix (solid black), the MLE (dotted blue)
and the MAP estimate (dashed red), using Equation (4.98) with $\lambda = 0.9$. We also list the condition number
of each matrix in the legend. We see that the MLE is often poorly conditioned, but the MAP estimate is
numerically well behaved. Adapted from Figure 1 of [SS05]. Generated by shrinkcov_plots.ipynb.

The benefits of this approach are illustrated in Figure 4.4. We consider fitting a 50-dimensional 
Gaussian to N = 100, N = 50 and N = 25 data points. We see that the MAP estimate is always 
well-conditioned, unlike the MLE (see Section 7.1.4.4 for a discussion of condition numbers). In 
particular, we see that the eigenvalue spectrum of the MAP estimate is much closer to that of the 
true matrix than the MLE’s spectrum. The eigenvectors, however, are unaffected. 
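The shrinkage rule of Equation (4.98) can be sketched as follows, assuming NumPy and synthetic data with D = 50 and N = 25 (so the MLE is singular). This is not the code behind Figure 4.4; the function name `shrink_covariance` and the identity true covariance are choices made for this example.

```python
import numpy as np

def shrink_covariance(Sigma_mle, lam):
    """Eq. (4.98): keep the diagonal, scale off-diagonal entries by (1 - lam)."""
    Sigma_map = (1.0 - lam) * Sigma_mle     # new array; the input is not modified
    np.fill_diagonal(Sigma_map, np.diag(Sigma_mle))
    return Sigma_map

rng = np.random.default_rng(0)
D, N = 50, 25                               # fewer samples than dimensions
Y = rng.normal(size=(N, D))                 # made-up data with true covariance I
Sigma_mle = np.cov(Y, rowvar=False, bias=True)
Sigma_map = shrink_covariance(Sigma_mle, lam=0.9)

rank_mle = np.linalg.matrix_rank(Sigma_mle)
min_eig_map = np.linalg.eigvalsh(Sigma_map).min()
print("rank of MLE:", rank_mle, "out of", D)
print("smallest eigenvalue of MAP estimate:", min_eig_map)
print("condition number of MAP estimate:", np.linalg.cond(Sigma_map))
```

The MAP estimate equals λ diag(Σ̂_mle) + (1 − λ) Σ̂_mle, which is the sum of a positive definite diagonal matrix (as long as every feature has nonzero sample variance) and a positive semidefinite one, so it is invertible even when N < D.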


4.5.3 Example: weight decay 


In Figure 1.7, we saw how using polynomial regression with too high a degree can result in
overfitting. One solution is to reduce the degree of the polynomial. However, a more general solution
is to penalize the magnitude of the weights (regression coefficients). We can do this by using a
zero-mean Gaussian prior, $p(\mathbf{w})$. The resulting MAP estimate is given by

$$\hat{\mathbf{w}}_{map} = \operatorname*{argmin}_{\mathbf{w}} \mathrm{NLL}(\mathbf{w}) + \lambda \|\mathbf{w}\|_2^2 \tag{4.99}$$

where $\|\mathbf{w}\|_2^2 = \sum_{d=1}^D w_d^2$. (We write $\mathbf{w}$ rather than $\boldsymbol{\theta}$, since it only really makes sense to penalize the
magnitude of weight vectors, rather than other parameters, such as bias terms or noise variances.)

Equation (4.99) is called $\ell_2$ regularization or weight decay. The larger the value of $\lambda$, the more
the parameters are penalized for being “large” (deviating from the zero-mean prior), and thus the
less flexible the model.

In the case of linear regression, this kind of penalization scheme is called ridge regression. For
example, consider the polynomial regression example from Section 1.2.2.2, where the predictor has
the form

$$f(x; \mathbf{w}) = \sum_{d=0}^D w_d x^d = \mathbf{w}^\top [1, x, x^2, \ldots, x^D] \tag{4.100}$$

Figure 4.5: (a-c) Ridge regression applied to a degree 14 polynomial fit to 21 datapoints. (d) MSE vs strength
of regularizer. The degree of regularization increases from left to right, so model complexity decreases from
left to right. Generated by linreg_poly_ridge.ipynb.

Suppose we use a high degree polynomial, say D = 14, even though we have a small dataset with
just N = 21 examples. MLE for the parameters will enable the model to fit the data very well, by
carefully adjusting the weights, but the resulting function is very “wiggly”, thus resulting in overfitting.
Figure 4.5 illustrates how increasing $\lambda$ can reduce overfitting. For more details on ridge regression,
see Section 11.3.


4.5.4 Picking the regularizer using a validation set 


A key question when using regularization is how to choose the strength of the regularizer $\lambda$: a small
value means we will focus on minimizing empirical risk, which may result in overfitting, whereas a
large value means we will focus on staying close to the prior, which may result in underfitting.

In this section, we describe a simple but very widely used method for choosing $\lambda$. The basic idea
is to partition the data into two disjoint sets, the training set $\mathcal{D}_{train}$ and a validation set $\mathcal{D}_{valid}$
(also called a development set). (Often we use about 80% of the data for the training set, and
20% for the validation set.) We fit the model on $\mathcal{D}_{train}$ (for each setting of $\lambda$) and then evaluate its
performance on $\mathcal{D}_{valid}$. We then pick the value of $\lambda$ that results in the best validation performance.
(This optimization method is a 1d example of grid search, discussed in Section 8.8.)

Figure 4.6: Schematic of 5-fold cross validation.

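To explain the method in more detail, we need some notation. Let us define the regularized
empirical risk on a dataset as follows:

$$R_\lambda(\boldsymbol{\theta}, \mathcal{D}) = \frac{1}{|\mathcal{D}|} \sum_{(\mathbf{x}, y) \in \mathcal{D}} \ell(y, f(\mathbf{x}; \boldsymbol{\theta})) + \lambda C(\boldsymbol{\theta}) \tag{4.101}$$

For each $\lambda$, we compute the parameter estimate

$$\hat{\boldsymbol{\theta}}_\lambda(\mathcal{D}_{train}) = \operatorname*{argmin}_{\boldsymbol{\theta}} R_\lambda(\boldsymbol{\theta}, \mathcal{D}_{train}) \tag{4.102}$$

We then compute the validation risk:

$$R_\lambda^{val} \triangleq R_0(\hat{\boldsymbol{\theta}}_\lambda(\mathcal{D}_{train}), \mathcal{D}_{valid}) \tag{4.103}$$

This is an estimate of the population risk, which is the expected loss under the true distribution
$p^*(\mathbf{x}, y)$. Finally we pick

$$\lambda^* = \operatorname*{argmin}_{\lambda \in \mathcal{S}} R_\lambda^{val} \tag{4.104}$$

(This requires fitting the model once for each value of $\lambda$ in $\mathcal{S}$, although in some cases, this can be
done more efficiently.)
After picking $\lambda^*$, we can refit the model to the entire dataset, $\mathcal{D} = \mathcal{D}_{train} \cup \mathcal{D}_{valid}$, to get

$$\hat{\boldsymbol{\theta}}^* = \operatorname*{argmin}_{\boldsymbol{\theta}} R_{\lambda^*}(\boldsymbol{\theta}, \mathcal{D}) \tag{4.105}$$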
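The procedure in Equations (4.102) to (4.105) can be sketched for ridge regression as follows. Several things here are assumptions made for illustration and are not from the text: the synthetic sine-wave data, the degree-9 polynomial features, the candidate set of λ values, and the use of the standard closed-form ridge solution (which minimizes RSS plus λ‖w‖², so λ is scaled differently from Equation (4.101), and the bias weight is penalized too).

```python
import numpy as np

def ridge_fit(X, y, lam):
    """Minimize RSS(w) + lam * ||w||^2 using the standard closed-form solution."""
    return np.linalg.solve(X.T @ X + lam * np.eye(X.shape[1]), X.T @ y)

def mse(X, y, w):
    """Unregularized risk R_0 under squared loss."""
    return float(np.mean((X @ w - y) ** 2))

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=40)
y = np.sin(3 * x) + 0.3 * rng.normal(size=40)    # made-up 1d regression data
X = np.vander(x, 10, increasing=True)            # degree-9 polynomial features

perm = rng.permutation(len(y))
train, valid = perm[:32], perm[32:]              # an 80/20 split
lambdas = [1e-6, 1e-4, 1e-2, 1.0, 100.0]         # candidate set S (hypothetical)

val_risk = {lam: mse(X[valid], y[valid], ridge_fit(X[train], y[train], lam))
            for lam in lambdas}                  # Eq. (4.103)
lam_star = min(val_risk, key=val_risk.get)       # Eq. (4.104)
w_star = ridge_fit(X, y, lam_star)               # refit on all data, Eq. (4.105)
print(val_risk, lam_star)
```

Note that the validation risk is computed with the unregularized loss, since the penalty is a training device and not part of the quantity we care about at prediction time.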


4.5.5 Cross-validation 


The above technique in Section 4.5.4 can work very well. However, if the size of the training set 
is small, leaving aside 20% for a validation set can result in an unreliable estimate of the model 
parameters. 
A simple but popular solution to this is to use cross validation (CV). The idea is as follows: we
split the training data into K folds; then, for each fold $k \in \{1, \ldots, K\}$, we train on all the folds but
the k'th, and test on the k'th, in a round-robin fashion, as sketched in Figure 4.6. Formally, we have

$$R_\lambda^{cv} \triangleq \frac{1}{K} \sum_{k=1}^K R_0(\hat{\boldsymbol{\theta}}_\lambda(\mathcal{D}_{-k}), \mathcal{D}_k) \tag{4.106}$$

where $\mathcal{D}_k$ is the data in the k'th fold, and $\mathcal{D}_{-k}$ is all the other data. This is called the cross-validated
risk. Figure 4.6 illustrates this procedure for K = 5. If we set K = N, we get a method known as
leave-one-out cross-validation, since we always train on N − 1 items and test on the remaining
one.

We can use the CV estimate as an objective inside of an optimization routine to pick the optimal
hyperparameter, $\hat{\lambda} = \operatorname*{argmin}_\lambda R_\lambda^{cv}$. Finally we combine all the available data (training and validation),
and re-estimate the model parameters using $\hat{\boldsymbol{\theta}} = \operatorname*{argmin}_{\boldsymbol{\theta}} R_{\hat{\lambda}}(\boldsymbol{\theta}, \mathcal{D})$. See Section 5.4.3 for more details.
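A minimal sketch of Equation (4.106) is given below, assuming NumPy. The helper `kfold_cv_risk` is hypothetical and takes the fitting and loss functions as arguments; the ridge solver, the synthetic data, and the candidate λ values are made up for illustration.

```python
import numpy as np

def kfold_cv_risk(X, y, lam, fit, loss, K=5, seed=0):
    """Cross-validated risk, Eq. (4.106): average held-out loss over K folds."""
    folds = np.array_split(np.random.default_rng(seed).permutation(len(y)), K)
    risks = []
    for k in range(K):
        test = folds[k]
        train = np.concatenate([folds[j] for j in range(K) if j != k])
        theta = fit(X[train], y[train], lam)      # train on all folds but the k'th
        risks.append(loss(X[test], y[test], theta))
    return float(np.mean(risks))

def ridge_fit(X, y, lam):
    """Standard closed-form ridge solution (an assumption; see Section 11.3)."""
    return np.linalg.solve(X.T @ X + lam * np.eye(X.shape[1]), X.T @ y)

def mse(X, y, w):
    return float(np.mean((X @ w - y) ** 2))

rng = np.random.default_rng(1)
X = rng.normal(size=(30, 5))                      # made-up inputs
y = X @ np.array([1.0, 0.0, -1.0, 0.0, 2.0]) + 0.5 * rng.normal(size=30)

cv = {lam: kfold_cv_risk(X, y, lam, ridge_fit, mse) for lam in [0.01, 1.0, 100.0]}
lam_hat = min(cv, key=cv.get)
loo = kfold_cv_risk(X, y, lam_hat, ridge_fit, mse, K=len(y))   # leave-one-out
print(cv, lam_hat, loo)
```

Using the same seed for every λ means all candidates are scored on identical folds, which removes one source of noise from the comparison.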


4.5.5.1 The one standard error rule 


CV gives an estimate of $R_\lambda$, but does not give any measure of uncertainty. A standard frequentist
measure of uncertainty of an estimate is the standard error of the mean, which is the standard deviation of
the sampling distribution of the estimate (see Section 4.7.1). We can compute this as follows. First
let $L_n = \ell(y_n, f(\mathbf{x}_n; \hat{\boldsymbol{\theta}}_\lambda(\mathcal{D}_{-n})))$ be the loss on the n'th example, where we use the parameters that
were estimated using whichever training fold excludes n. (Note that $L_n$ depends on $\lambda$, but we drop
this from the notation.) Next let $\hat{\mu} = \frac{1}{N} \sum_{n=1}^N L_n$ be the empirical mean and $\hat{\sigma}^2 = \frac{1}{N} \sum_{n=1}^N (L_n - \hat{\mu})^2$
be the empirical variance. Given this, we define our estimate to be $\hat{\mu}$, and the standard error of
this estimate to be $\mathrm{se}(\hat{\mu}) = \frac{\hat{\sigma}}{\sqrt{N}}$. Note that $\sigma$ measures the intrinsic variability of $L_n$ across samples,
whereas $\mathrm{se}(\hat{\mu})$ measures our uncertainty about the mean $\hat{\mu}$.

Suppose we apply CV to a set of models and compute the mean and se of their estimated risks. 
A common heuristic for picking a model from these noisy estimates is to pick the value which 
corresponds to the simplest model whose risk is no more than one standard error above the risk of 
the best model; this is called the one-standard error rule [HTF01, p216]. 
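The one-standard-error rule can be sketched as follows. The per-example losses below are made-up numbers standing in for the CV losses $L_n$ of three candidate models, and the function names are hypothetical. The candidates must be listed from simplest to most complex (for ridge regression, from largest λ to smallest).

```python
import math

def mean_and_se(losses):
    """Empirical mean of per-example CV losses and its standard error sigma_hat / sqrt(N)."""
    N = len(losses)
    mu = sum(losses) / N
    var = sum((x - mu) ** 2 for x in losses) / N
    return mu, math.sqrt(var / N)

def one_standard_error_rule(candidates):
    """candidates: list of (name, losses), ordered from simplest to most complex model."""
    stats = [(name, *mean_and_se(losses)) for name, losses in candidates]
    _, best_mu, best_se = min(stats, key=lambda s: s[1])
    threshold = best_mu + best_se
    for name, mu, _ in stats:               # first (simplest) model within one se of the best
        if mu <= threshold:
            return name

candidates = [                              # made-up losses, simplest model first
    ("lam=100", [2.0, 2.2, 1.8, 2.0]),
    ("lam=1", [1.0, 1.4, 0.8, 1.2]),
    ("lam=0.01", [0.6, 1.6, 0.4, 1.4]),
]
chosen = one_standard_error_rule(candidates)
print(chosen)
```

In this toy example the least regularized model has the lowest mean loss, but its standard error is large enough that the intermediate model is statistically indistinguishable from it, so the rule prefers the simpler one.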


4.5.5.2 Example: ridge regression 


As an example, consider picking the strength of the $\ell_2$ regularizer for the ridge regression problem
in Section 4.5.3. In Figure 4.7a, we plot the error vs $\log(\lambda)$ on the train set (blue) and test set
(red curve). We see that the test error has a U-shaped curve, where it decreases as we increase the
regularizer, and then increases as we start to underfit. In Figure 4.7b, we plot the 5-fold CV estimate
of the test MSE vs $\log(\lambda)$. We see that the minimum CV error is close to the optimal value for the test
set (although it does underestimate the spike in the test error for large $\lambda$, due to the small
sample size).


4.5.6 Early stopping 


Figure 4.7: Ridge regression is applied to a degree 14 polynomial fit to 21 datapoints shown in Figure 4.5
for different values of the regularizer $\lambda$. The degree of regularization increases from left to right, so model
complexity decreases from left to right. (a) MSE on train (blue) and test (red) vs $\log(\lambda)$. (b) 5-fold cross-validation
estimate of test MSE; error bars are standard error of the mean. Vertical line is the point chosen
by the one standard error rule. Generated by polyfitRidgeCV.ipynb.

A very simple form of regularization, which is often very effective in practice (especially for complex
models), is known as early stopping. This leverages the fact that optimization algorithms are
iterative, and so they take many steps to move away from the initial parameter estimates. If we detect
signs of overfitting (by monitoring performance on the validation set), we can stop the optimization
process, to prevent the model memorizing too much information about the training set. See Figure 4.8
for an illustration.
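One common way to turn early stopping into a concrete rule is to stop once the validation loss has not improved for a fixed number of epochs. The sketch below applies that rule to a recorded sequence of validation losses; the `patience` parameter, the function name, and the loss values are all assumptions made for this example and are not taken from the text or from Figure 4.8.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return (stop_epoch, best_epoch), with 0-based epochs.

    Stops once the validation loss has failed to improve for `patience`
    consecutive epochs; `patience` is a hypothetical knob, not from the text.
    """
    best_loss, best_epoch = float("inf"), -1
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch
    return len(val_losses) - 1, best_epoch

# Made-up validation curve: improves, then rises as the model starts to overfit.
val_losses = [0.90, 0.70, 0.60, 0.55, 0.57, 0.60, 0.66, 0.75, 0.90]
stop, best = early_stopping_epoch(val_losses, patience=3)
print(stop, best)
```

In a real training loop one would typically keep a copy of the parameters from the best epoch and restore them when stopping, since the final parameters are already past the point of lowest validation loss.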


4.5.7 Using more data 


As the amount of data increases, the chance of overfitting (for a model of fixed complexity) decreases 
(assuming the data contains suitably informative examples, and is not too redundant). This is 
illustrated in Figure 4.9. We show the MSE on the training and test sets for four different models 
(polynomials of increasing degree) as a function of the training set size N. (A plot of error vs training 
set size is known as a learning curve.) The horizontal black line represents the Bayes error, which 
is the error of the optimal predictor (the true model) due to inherent noise. (In this example, the
true model is a degree 2 polynomial, and the noise has a variance of $\sigma^2 = 4$; this is called the noise
floor, since we cannot go below it.)

We notice several interesting things. First, the test error for degree 1 remains high, even as N 
increases, since the model is too simple to capture the truth; this is called underfitting. The test 
error for the other models decreases to the optimal level (the noise floor), but it decreases more 
rapidly for the simpler models, since they have fewer parameters to estimate. The gap between the 
test error and training error is larger for more complex models, but decreases as N grows. 

Another interesting thing we can note is that the training error (blue line) initially increases with
N, at least for the models that are sufficiently flexible. The reason for this is as follows: as the data
set gets larger, we observe more distinct input-output pattern combinations, so the task of fitting the
data becomes harder. However, eventually the training set will come to resemble the test set, and
the error rates will converge, and will reflect the optimal performance of that model.

Figure 4.8: Performance of a text classifier (a neural network applied to a bag of word embeddings using
average pooling) vs number of training epochs on the IMDB movie sentiment dataset. Blue = train, red =
validation. (a) Cross entropy loss. Early stopping is triggered at about epoch 25. (b) Classification accuracy.
Generated by imdb_mlp_bow_tf.ipynb.

Figure 4.9: MSE on training and test sets vs size of training set, for data generated from a degree 2 polynomial
with Gaussian noise of variance $\sigma^2 = 4$. Generated by linreg_poly_vs_n.ipynb.


4.6 Bayesian statistics * 


So far, we have discussed several ways to estimate parameters from data. However, these approaches 
ignore any uncertainty in the estimates, which can be important for some applications, such as 
active learning, or avoiding overfitting, or just knowing how much to trust the estimate of some 
scientifically meaningful quantity. In statistics, modeling uncertainty about parameters using a 
probability distribution (as opposed to just computing a point estimate) is known as inference. 

In this section, we use the posterior distribution to represent our uncertainty. This is the 
approach adopted in the field of Bayesian statistics. We give a brief introduction here, but more 
details can be found in the sequel to this book, [Mur23], as well as other good books, such as [Lam18;
Kru15; McE20; Gel+14; MKL21; MFR20].

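To compute the posterior, we start with a prior distribution $p(\boldsymbol{\theta})$, which reflects what we know
before seeing the data. We then define a likelihood function $p(\mathcal{D}|\boldsymbol{\theta})$, which reflects the data we
expect to see for each setting of the parameters. We then use Bayes rule to condition the prior on
the observed data to compute the posterior $p(\boldsymbol{\theta}|\mathcal{D})$ as follows:

$$p(\boldsymbol{\theta}|\mathcal{D}) = \frac{p(\boldsymbol{\theta})\,p(\mathcal{D}|\boldsymbol{\theta})}{p(\mathcal{D})} = \frac{p(\boldsymbol{\theta})\,p(\mathcal{D}|\boldsymbol{\theta})}{\int p(\boldsymbol{\theta}')\,p(\mathcal{D}|\boldsymbol{\theta}')\,d\boldsymbol{\theta}'} \tag{4.107}$$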


The denominator $p(\mathcal{D})$ is called the marginal likelihood, since it is computed by marginalizing
over (or integrating out) the unknown $\boldsymbol{\theta}$. This can be interpreted as the average probability of
the data, where the average is wrt the prior. Note, however, that $p(\mathcal{D})$ is a constant, independent of
$\boldsymbol{\theta}$, so we will often ignore it when we just want to infer the relative probabilities of $\boldsymbol{\theta}$ values.
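
As a concrete illustration of Equation (4.107), the following minimal sketch approximates the posterior on a grid for a one-dimensional parameter, computing the marginal likelihood by numerical integration. The coin-tossing setup (a Beta(2,2) prior with 4 heads and 1 tail), the helper name `grid_posterior`, and the grid size are illustrative assumptions, not something prescribed by the text.

```python
def grid_posterior(prior, likelihood, num_points=1000):
    """Approximate p(theta|D) on a midpoint grid over [0, 1]."""
    width = 1.0 / num_points
    thetas = [(i + 0.5) * width for i in range(num_points)]
    # Numerator of Bayes rule: prior times likelihood at each grid point.
    unnorm = [prior(t) * likelihood(t) for t in thetas]
    # Denominator: marginal likelihood p(D), via the midpoint rule.
    evidence = sum(unnorm) * width
    posterior = [u / evidence for u in unnorm]
    return thetas, posterior, evidence, width

# Assumed example: Beta(2, 2) prior density and a Bernoulli likelihood
# for 4 heads and 1 tail.
prior = lambda t: 6.0 * t * (1.0 - t)
likelihood = lambda t: t**4 * (1.0 - t)

thetas, posterior, evidence, width = grid_posterior(prior, likelihood)
mode = thetas[posterior.index(max(posterior))]
print(f"p(D) ~= {evidence:.5f}, posterior mode ~= {mode:.3f}")
```

Dividing by the evidence only rescales the curve: the location of the mode and the ratio of posterior densities at any two parameter values are unchanged, which is why the denominator can often be ignored.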

Equation (4.107) is analogous to the use of Bayes rule for COVID-19 testing in Section 2.3.1. The 
difference is that the unknowns correspond to parameters of a statistical model, rather than the 
unknown disease state of a patient. In addition, we usually condition on a set of observations D, as 
opposed to a single observation (such as a single test outcome). In particular, for a supervised or 
conditional model, the observed data has the form $\mathcal{D} = \{(\boldsymbol{x}_n, \boldsymbol{y}_n) : n = 1:N\}$. For an unsupervised
or unconditional model, the observed data has the form $\mathcal{D} = \{(\boldsymbol{y}_n) : n = 1:N\}$.

Once we have computed the posterior over the parameters, we can compute the posterior 
predictive distribution over outputs given inputs by marginalizing out the unknown parameters. 
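In the supervised/conditional case, this becomes

$$p(\boldsymbol{y}|\boldsymbol{x}, \mathcal{D}) = \int p(\boldsymbol{y}|\boldsymbol{x}, \boldsymbol{\theta})\,p(\boldsymbol{\theta}|\mathcal{D})\,d\boldsymbol{\theta} \tag{4.108}$$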


This can be viewed as a form of Bayes model averaging (BMA), since we are making predictions 
using an infinite set of models (parameter values), each one weighted by how likely it is. The use of 
BMA reduces the chance of overfitting (Section 1.2.3), since we are not just using the single best 
model. 


4.6.1 Conjugate priors 


In this section, we consider a set of (prior, likelihood) pairs for which we can compute the posterior 
in closed form. In particular, we will use priors that are “conjugate” to the likelihood. We say that
a prior $p(\boldsymbol{\theta}) \in \mathcal{F}$ is a conjugate prior for a likelihood function $p(\mathcal{D}|\boldsymbol{\theta})$ if the posterior is in the
same parameterized family as the prior, i.e., $p(\boldsymbol{\theta}|\mathcal{D}) \in \mathcal{F}$. In other words, $\mathcal{F}$ is closed under Bayesian
updating. If the family $\mathcal{F}$ corresponds to the exponential family (defined in Section 3.4), then the
computations can be performed in closed form.

In the sections below, we give some common examples of this framework, which we will use later in 
the book. For simplicity, we focus on unconditional models (i.e., there are only outcomes or targets 
y, and no inputs or features x); we relax this assumption in Section 4.6.7. 


4.6.2 The beta-binomial model 


Suppose we toss a coin N times, and want to infer the probability of heads. Let $y_n = 1$ denote the
event that the n’th trial was heads, $y_n = 0$ represent the event that the n’th trial was tails, and let
$\mathcal{D} = \{y_n : n = 1:N\}$ be all the data. We assume $y_n \sim \mathrm{Ber}(\theta)$, where $\theta \in [0,1]$ is the rate parameter
(probability of heads). In this section, we discuss how to compute $p(\theta|\mathcal{D})$.


4.6.2.1 Bernoulli likelihood 


We assume the data are iid or independent and identically distributed. Thus the likelihood
has the form

$$p(\mathcal{D}|\theta) = \prod_{n=1}^{N} \theta^{y_n}(1-\theta)^{1-y_n} = \theta^{N_1}(1-\theta)^{N_0} \tag{4.109}$$

where we have defined $N_1 = \sum_{n=1}^{N} \mathbb{I}(y_n = 1)$ and $N_0 = \sum_{n=1}^{N} \mathbb{I}(y_n = 0)$, representing the number of
heads and tails. These counts are called the sufficient statistics of the data, since this is all we
need to know about $\mathcal{D}$ to infer $\theta$. The total count, $N = N_0 + N_1$, is called the sample size.


4.6.2.2 Binomial likelihood 


Note that we can also consider a Binomial likelihood model, in which we perform N trials and observe 
the number of heads, y, rather than observing a sequence of coin tosses. Now the likelihood has the 
following form: 


$$p(\mathcal{D}|\theta) = \mathrm{Bin}(y|N, \theta) = \binom{N}{y}\theta^{y}(1-\theta)^{N-y} \tag{4.110}$$


The scaling factor $\binom{N}{y}$ is independent of $\theta$, so we can ignore it. Thus this likelihood is proportional
to the Bernoulli likelihood in Equation (4.109), so our inferences about $\theta$ will be the same for both
models.


4.6.2.3 Prior 


To simplify the computations, we will assume that the prior $p(\theta) \in \mathcal{F}$ is a conjugate prior for the
likelihood function $p(y|\theta)$. This means that the posterior is in the same parameterized family as the
prior, i.e., $p(\theta|\mathcal{D}) \in \mathcal{F}$.

To ensure this property when using the Bernoulli (or Binomial) likelihood, we should use a prior
of the following form:

$$p(\theta) \propto \theta^{\breve{\alpha}-1}(1-\theta)^{\breve{\beta}-1} \propto \mathrm{Beta}(\theta|\breve{\alpha}, \breve{\beta}) \tag{4.111}$$

We recognize this as the pdf of a beta distribution (see Section 2.7.4).

Figure 4.10: Updating a Beta prior with a Bernoulli likelihood with sufficient statistics $N_1 = 4$, $N_0 = 1$. (a)
Beta(2,2) prior. (b) Uniform Beta(1,1) prior. Generated by beta_binom_post_plot.ipynb.


4.6.2.4 Posterior 


If we multiply the Bernoulli likelihood in Equation (4.109) with the beta prior in Equation (2.136) 
we get a beta posterior: 


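\begin{align}
p(\theta|\mathcal{D}) &\propto \theta^{N_1}(1-\theta)^{N_0}\,\theta^{\breve{\alpha}-1}(1-\theta)^{\breve{\beta}-1} \tag{4.112} \\
&\propto \mathrm{Beta}(\theta|\breve{\alpha}+N_1, \breve{\beta}+N_0) \tag{4.113} \\
&= \mathrm{Beta}(\theta|\widehat{\alpha}, \widehat{\beta}) \tag{4.114}
\end{align}


where $\widehat{\alpha} \triangleq \breve{\alpha}+N_1$ and $\widehat{\beta} \triangleq \breve{\beta}+N_0$ are the parameters of the posterior. Since the posterior has the same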
functional form as the prior, we say that the beta distribution is a conjugate prior for the Bernoulli 
likelihood. 

The parameters of the prior are called hyper-parameters. It is clear that (in this example) the 
hyper-parameters play a role analogous to the sufficient statistics; they are therefore often called 
pseudo counts. We see that we can compute the posterior by simply adding the observed counts 
(from the likelihood) to the pseudo counts (from the prior). 

The strength of the prior is controlled by $\breve{N} = \breve{\alpha} + \breve{\beta}$; this is called the equivalent sample size,
since it plays a role analogous to the observed sample size, $N = N_0 + N_1$.
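
The pseudo-count view of the update can be sketched in a few lines. The function name `beta_update` and the five-toss dataset are hypothetical choices for illustration; the only thing taken from the text is the rule that the posterior parameters are the prior parameters plus the observed counts.

```python
def beta_update(alpha_prior, beta_prior, n_heads, n_tails):
    """Conjugate update: add observed counts to the prior pseudo counts."""
    return alpha_prior + n_heads, beta_prior + n_tails

# Assumed toy data: 1 = heads, 0 = tails.
data = [1, 1, 0, 1, 1]
n1 = sum(data)
n0 = len(data) - n1

# Batch update with a Beta(2, 2) prior.
alpha_post, beta_post = beta_update(2, 2, n1, n0)

# Sequential update, one observation at a time, starting from the same prior.
alpha_seq, beta_seq = 2, 2
for y in data:
    alpha_seq, beta_seq = beta_update(alpha_seq, beta_seq, y, 1 - y)

print((alpha_post, beta_post), (alpha_seq, beta_seq))
```

Because only the counts matter, processing the data in one batch or one toss at a time leads to the same posterior.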


4.6.2.5 Example 


For example, suppose we set $\breve{\alpha} = \breve{\beta} = 2$. This is like saying we believe we have already seen two heads
and two tails before we see the actual data; this is a very weak preference for the value of $\theta = 0.5$.
The effect of using this prior is illustrated in Figure 4.10a. We see the posterior (blue line) is a
“compromise” between the prior (red line) and the likelihood (black line).
If we set $\breve{\alpha} = \breve{\beta} = 1$, the corresponding prior becomes the uniform distribution:

$$p(\theta) = \mathrm{Beta}(\theta|1, 1) \propto \theta^{0}(1-\theta)^{0} = \mathrm{Unif}(\theta|0, 1) \tag{4.115}$$

The effect of using this prior is illustrated in Figure 4.10b. We see that the posterior has exactly the
same shape as the likelihood, since the prior was “uninformative”.


4.6.2.6 Posterior mode (MAP estimate)


The most probable value of the parameter is given by the MAP estimate 


\begin{align}
\hat{\theta}_{\mathrm{map}} &= \arg\max_{\theta} p(\theta|\mathcal{D}) \tag{4.116} \\
&= \arg\max_{\theta} \log p(\theta|\mathcal{D}) \tag{4.117} \\
&= \arg\max_{\theta} \log p(\theta) + \log p(\mathcal{D}|\theta) \tag{4.118}
\end{align}


Using calculus, one can show that this is given by

$$\hat{\theta}_{\mathrm{map}} = \frac{\breve{\alpha}+N_1-1}{\breve{\alpha}+N_1-1+\breve{\beta}+N_0-1} \tag{4.119}$$

If we use a $\mathrm{Beta}(\theta|2, 2)$ prior, this amounts to add-one smoothing:

$$\hat{\theta}_{\mathrm{map}} = \frac{N_1+1}{N_1+1+N_0+1} = \frac{N_1+1}{N+2} \tag{4.120}$$

If we use a uniform prior, $p(\theta) \propto 1$, the MAP estimate becomes the MLE, since $\log p(\theta) = 0$:

$$\hat{\theta}_{\mathrm{mle}} = \arg\max_{\theta} \log p(\mathcal{D}|\theta) \tag{4.121}$$

When we use a Beta prior, the uniform distribution corresponds to $\breve{\alpha} = \breve{\beta} = 1$. In this case, the MAP estimate
reduces to the MLE:

$$\hat{\theta}_{\mathrm{mle}} = \frac{N_1}{N_1+N_0} = \frac{N_1}{N} \tag{4.122}$$

If $N_1 = 0$, we will estimate that $p(Y = 1) = 0.0$, which says that we do not predict any future
observations to be 1. This is a very extreme estimate, that is likely due to insufficient data. We can
solve this problem using a MAP estimate with a stronger prior, or using a fully Bayesian approach,
in which we marginalize out $\theta$ instead of estimating it, as explained in Section 4.6.2.9.
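
The zero-count problem is easy to reproduce numerically. The sketch below implements Equations (4.119) and (4.122) directly; the function names and the dataset of three tails in a row are assumptions made for illustration.

```python
def beta_map(alpha_prior, beta_prior, n1, n0):
    """Posterior mode of Beta(alpha_prior + n1, beta_prior + n0), Equation (4.119)."""
    numerator = alpha_prior + n1 - 1
    denominator = alpha_prior + n1 - 1 + beta_prior + n0 - 1
    return numerator / denominator

def bernoulli_mle(n1, n0):
    """Maximum likelihood estimate N1 / N, Equation (4.122)."""
    return n1 / (n1 + n0)

# Assumed data: no heads and three tails.
n1, n0 = 0, 3
mle = bernoulli_mle(n1, n0)              # extreme estimate of exactly 0
map_uniform = beta_map(1, 1, n1, n0)     # uniform prior: same as the MLE
map_smoothed = beta_map(2, 2, n1, n0)    # Beta(2,2) prior: add-one smoothing
print(mle, map_uniform, map_smoothed)
```

With the Beta(2,2) prior the estimate is (0 + 1)/(3 + 2) = 0.2, so heads is no longer declared impossible after a short run of tails.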


4.6.2.7 Posterior mean 


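The posterior mode can be a poor summary of the posterior, since it corresponds to a single point.
The posterior mean is a more robust estimate, since it integrates over the whole space.
If $p(\theta|\mathcal{D}) = \mathrm{Beta}(\theta|\widehat{\alpha}, \widehat{\beta})$, then the posterior mean is given by

$$\bar{\theta} \triangleq \mathbb{E}[\theta|\mathcal{D}] = \frac{\widehat{\alpha}}{\widehat{\beta}+\widehat{\alpha}} = \frac{\widehat{\alpha}}{\widehat{N}} \tag{4.123}$$

where $\widehat{N} = \widehat{\beta} + \widehat{\alpha}$ is the strength (equivalent sample size) of the posterior.
We will now show that the posterior mean is a convex combination of the prior mean, $m = \breve{\alpha}/\breve{N}$
(where $\breve{N} \triangleq \breve{\alpha} + \breve{\beta}$ is the prior strength), and the MLE, $\hat{\theta}_{\mathrm{mle}} = N_1/N$:

$$\mathbb{E}[\theta|\mathcal{D}] = \frac{\breve{\alpha}+N_1}{\breve{\alpha}+N_1+\breve{\beta}+N_0} = \frac{\breve{N} m + N_1}{N+\breve{N}} = \frac{\breve{N}}{N+\breve{N}}\,m + \frac{N}{N+\breve{N}}\,\frac{N_1}{N} = \lambda m + (1-\lambda)\,\hat{\theta}_{\mathrm{mle}} \tag{4.124}$$

where $\lambda = \breve{N}/\widehat{N}$ is the ratio of the prior to posterior equivalent sample size. So the weaker the prior,
the smaller is $\lambda$, and hence the closer the posterior mean is to the MLE.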


4.6.2.8 Posterior variance 


To capture some notion of uncertainty in our estimate, a common approach is to compute the 
standard error of our estimate, which is just the posterior standard deviation: 


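$$\mathrm{se}(\theta) = \sqrt{\mathbb{V}[\theta|\mathcal{D}]} \tag{4.125}$$


In the case of the Bernoulli model, we showed that the posterior is a beta distribution. The variance
of the beta posterior is given by

$$\mathbb{V}[\theta|\mathcal{D}] = \frac{\widehat{\alpha}\,\widehat{\beta}}{(\widehat{\alpha}+\widehat{\beta})^2(\widehat{\alpha}+\widehat{\beta}+1)} = \mathbb{E}[\theta|\mathcal{D}]^2\,\frac{\widehat{\beta}}{\widehat{\alpha}(1+\widehat{\alpha}+\widehat{\beta})} \tag{4.126}$$

where $\widehat{\alpha} = \breve{\alpha}+N_1$ and $\widehat{\beta} = \breve{\beta}+N_0$. If $N \gg \breve{\alpha} + \breve{\beta}$, this simplifies to

$$\mathbb{V}[\theta|\mathcal{D}] \approx \frac{N_1 N_0}{N^3} = \frac{\hat{\theta}(1-\hat{\theta})}{N} \tag{4.127}$$

where $\hat{\theta}$ is the MLE. Hence the standard error is given by

$$\sigma = \sqrt{\mathbb{V}[\theta|\mathcal{D}]} \approx \sqrt{\frac{\hat{\theta}(1-\hat{\theta})}{N}} \tag{4.128}$$


We see that the uncertainty goes down at a rate of $1/\sqrt{N}$. We also see that the uncertainty (variance)
is maximized when $\hat{\theta} = 0.5$, and is minimized when $\hat{\theta}$ is close to 0 or 1. This makes sense, since it is
easier to be sure that a coin is biased than to be sure that it is fair.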


4.6.2.9 Posterior predictive 


Suppose we want to predict future observations. A very common approach is to first compute an 
estimate of the parameters based on training data, $\hat{\theta}(\mathcal{D})$, and then to plug that parameter back into
the model and use $p(y|\hat{\theta})$ to predict the future; this is called a plug-in approximation. However,
this can result in overfitting. As an extreme example, suppose we have seen N = 3 heads in a row.
The MLE is $\hat{\theta} = 3/3 = 1$. However, if we use this estimate, we would predict that tails are impossible.

One solution to this is to compute a MAP estimate, and plug that in, as we discussed in Section 4.5.1.
Here we discuss a fully Bayesian solution, in which we marginalize out $\theta$.

Figure 4.11: Illustration of sequential Bayesian updating for the beta-Bernoulli model. Each colored box
represents the predicted distribution $p(x_t|h_t)$, where $h_t = (N_{1,t}, N_{0,t})$ is the sufficient statistic derived from
history of observations up until time t, namely the total number of heads and tails. The probability of heads
(blue bar) is given by $p(x_t = 1|h_t) = (N_{1,t} + 1)/(t + 2)$, assuming we start with a uniform $\mathrm{Beta}(\theta|1,1)$ prior.
From Figure 8 of [Ort+19]. Used with kind permission of Pedro Ortega.


Bernoulli model 


For the Bernoulli model, the resulting posterior predictive distribution has the form 


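\begin{align}
p(y = 1|\mathcal{D}) &= \int_0^1 p(y = 1|\theta)\,p(\theta|\mathcal{D})\,d\theta \tag{4.129} \\
&= \int_0^1 \theta\,\mathrm{Beta}(\theta|\widehat{\alpha}, \widehat{\beta})\,d\theta = \mathbb{E}[\theta|\mathcal{D}] = \frac{\widehat{\alpha}}{\widehat{\alpha}+\widehat{\beta}} \tag{4.130}
\end{align}

In Section 4.5.1, we had to use the Beta(2,2) prior to recover add-one smoothing, which is a
rather unnatural prior. In the Bayesian approach, we can get the same effect using a uniform prior,
$p(\theta) = \mathrm{Beta}(\theta|1, 1)$, since the predictive distribution becomes

$$p(y = 1|\mathcal{D}) = \frac{N_1+1}{N_1+N_0+2} \tag{4.131}$$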
This is known as Laplace’s rule of succession. See Figure 4.11 for an illustration of this in the 
sequential setting. 
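
The rule of succession can be traced step by step in the sequential setting. This is a minimal sketch: the function name `predictive_heads` and the three-heads sequence are assumptions chosen to mirror the extreme example discussed earlier.

```python
def predictive_heads(n1, n0, alpha_prior=1.0, beta_prior=1.0):
    """Posterior predictive p(y=1|D), i.e. the posterior mean of theta."""
    return (alpha_prior + n1) / (alpha_prior + n1 + beta_prior + n0)

# Assumed observation sequence: three heads in a row.
observations = [1, 1, 1]

n1 = n0 = 0
predictions = [predictive_heads(n1, n0)]   # prediction before seeing any data
for y in observations:
    n1 += y
    n0 += 1 - y
    predictions.append(predictive_heads(n1, n0))

plug_in_mle = n1 / (n1 + n0)   # plug-in estimate predicts heads with certainty
print(predictions, plug_in_mle)
```

After three heads the Bayesian prediction is 4/5 rather than 1, so tails still receives nonzero probability.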


Binomial model 


Now suppose we were interested in predicting the number of heads in M > 1 future coin tossing 
trials, i.e., we are using the binomial model instead of the Bernoulli model. The posterior over $\theta$ is
the same as before, but the posterior predictive distribution is different:

\begin{align}
p(y|\mathcal{D}, M) &= \int_0^1 \mathrm{Bin}(y|M, \theta)\,\mathrm{Beta}(\theta|\widehat{\alpha}, \widehat{\beta})\,d\theta \tag{4.132} \\
&= \binom{M}{y}\frac{1}{B(\widehat{\alpha}, \widehat{\beta})} \int_0^1 \theta^{y}(1-\theta)^{M-y}\,\theta^{\widehat{\alpha}-1}(1-\theta)^{\widehat{\beta}-1}\,d\theta \tag{4.133}
\end{align}

We recognize the integral as the normalization constant for a $\mathrm{Beta}(\widehat{\alpha}+y, M-y+\widehat{\beta})$ distribution.
Hence

$$\int_0^1 \theta^{y+\widehat{\alpha}-1}(1-\theta)^{M-y+\widehat{\beta}-1}\,d\theta = B(y+\widehat{\alpha}, M-y+\widehat{\beta}) \tag{4.134}$$

Thus we find that the posterior predictive is given by the following, known as the (compound)
beta-binomial distribution:

$$\mathrm{Bb}(x|M, \widehat{\alpha}, \widehat{\beta}) \triangleq \binom{M}{x}\frac{B(x+\widehat{\alpha}, M-x+\widehat{\beta})}{B(\widehat{\alpha}, \widehat{\beta})} \tag{4.135}$$

In Figure 4.12(a), we plot the posterior predictive density for M = 10 after seeing $N_1 = 4$ heads
and $N_0 = 1$ tails, when using a uniform Beta(1,1) prior. In Figure 4.12(b), we plot the plug-in
approximation, given by

\begin{align}
p(\theta|\mathcal{D}) &\approx \delta(\theta - \hat{\theta}) \tag{4.136} \\
p(y|\mathcal{D}, M) &= \int_0^1 \mathrm{Bin}(y|M, \theta)\,p(\theta|\mathcal{D})\,d\theta = \mathrm{Bin}(y|M, \hat{\theta}) \tag{4.137}
\end{align}

where $\hat{\theta}$ is the MAP estimate. Looking at Figure 4.12, we see that the Bayesian prediction has
longer tails, spreading its probability mass more widely, and is therefore less prone to overfitting and
black-swan type paradoxes. (Note that we use a uniform prior in both cases, so the difference is not
arising due to the use of a prior; rather, it is due to the fact that the Bayesian approach integrates
out the unknown parameters when making its predictions.)

Figure 4.12: (a) Posterior predictive distributions for 10 future trials after seeing $N_1 = 4$ heads and $N_0 = 1$
tails. (b) Plug-in approximation based on the same data. In both cases, we use a uniform prior. Generated by
beta_binom_post_pred_plot.ipynb.
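
The comparison between the beta-binomial predictive and the plug-in binomial can be reproduced numerically. The sketch below is not the notebook named in the figure caption; it is a standard-library reimplementation under the same assumed setting (uniform prior, 4 heads, 1 tail, M = 10), with helper names chosen for illustration.

```python
import math

def log_beta(a, b):
    """Log of the beta function B(a, b)."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def beta_binomial_pmf(x, m, alpha, beta):
    """Bb(x|m, alpha, beta), as in Equation (4.135)."""
    log_ratio = log_beta(x + alpha, m - x + beta) - log_beta(alpha, beta)
    return math.comb(m, x) * math.exp(log_ratio)

def binomial_pmf(x, m, theta):
    return math.comb(m, x) * theta**x * (1.0 - theta)**(m - x)

def variance(pmf):
    mean = sum(x * p for x, p in enumerate(pmf))
    return sum((x - mean) ** 2 * p for x, p in enumerate(pmf))

# Posterior Beta(5, 2) from a uniform Beta(1, 1) prior with N1 = 4, N0 = 1.
alpha_post, beta_post = 1 + 4, 1 + 1
m = 10
theta_map = (alpha_post - 1) / (alpha_post + beta_post - 2)   # 0.8

bayes = [beta_binomial_pmf(x, m, alpha_post, beta_post) for x in range(m + 1)]
plugin = [binomial_pmf(x, m, theta_map) for x in range(m + 1)]
print(f"variance: Bayes {variance(bayes):.2f}, plug-in {variance(plugin):.2f}")
```

The Bayesian predictive is noticeably wider than the plug-in one, which matches the longer tails described above.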


4.6.2.10 Marginal likelihood 


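The marginal likelihood or evidence for a model $\mathcal{M}$ is defined as

$$p(\mathcal{D}|\mathcal{M}) = \int p(\theta|\mathcal{M})\,p(\mathcal{D}|\theta, \mathcal{M})\,d\theta \tag{4.138}$$

When performing inference for the parameters of a specific model, we can ignore this term, since it is
constant wrt $\theta$. However, this quantity plays a vital role when choosing between different models,
as we discuss in Section 5.2.2. It is also useful for estimating the hyperparameters from data (an
approach known as empirical Bayes), as we discuss in Section 4.6.5.3.

In general, computing the marginal likelihood can be hard. However, in the case of the
beta-Bernoulli model, the marginal likelihood is proportional to the ratio of the posterior normalizer to
the prior normalizer. To see this, recall that the posterior for the beta-binomial models is given by
$p(\theta|\mathcal{D}) = \mathrm{Beta}(\theta|a', b')$, where $a' = a + N_1$ and $b' = b + N_0$. We know the normalization constant of
the posterior is $B(a', b')$. Hence

\begin{align}
p(\theta|\mathcal{D}) &= \frac{p(\mathcal{D}|\theta)\,p(\theta)}{p(\mathcal{D})} \tag{4.139} \\
&= \frac{1}{p(\mathcal{D})}\left[\frac{1}{B(a, b)}\theta^{a-1}(1-\theta)^{b-1}\right]\left[\binom{N}{N_1}\theta^{N_1}(1-\theta)^{N_0}\right] \tag{4.140} \\
&= \binom{N}{N_1}\frac{1}{p(\mathcal{D})}\frac{1}{B(a, b)}\left[\theta^{a+N_1-1}(1-\theta)^{b+N_0-1}\right] \tag{4.141}
\end{align}

So

\begin{align}
\frac{1}{B(a+N_1, b+N_0)} &= \binom{N}{N_1}\frac{1}{p(\mathcal{D})}\frac{1}{B(a, b)} \tag{4.142} \\
p(\mathcal{D}) &= \binom{N}{N_1}\frac{B(a+N_1, b+N_0)}{B(a, b)} \tag{4.143}
\end{align}


The marginal likelihood for the beta-Bernoulli model is the same as above, except it is missing the
$\binom{N}{N_1}$ term.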

4.6.2.11 Mixtures of conjugate priors 


The beta distribution is a conjugate prior for the binomial likelihood, which enables us to easily 
compute the posterior in closed form, as we have seen. However, this prior is rather restrictive. For 
example, suppose we want to predict the outcome of a coin toss at a casino, and we believe that the 
coin may be fair, but may equally likely be biased towards heads. This prior cannot be represented 
by a beta distribution. Fortunately, it can be represented as a mixture of beta distributions. 
For example, we might use 


$$p(\theta) = 0.5\,\mathrm{Beta}(\theta|20, 20) + 0.5\,\mathrm{Beta}(\theta|30, 10) \tag{4.144}$$


If $\theta$ comes from the first distribution, the coin is fair, but if it comes from the second, it is biased
towards heads.

We can represent a mixture by introducing a latent indicator variable h, where h = k means that
$\theta$ comes from mixture component k. The prior has the form

$$p(\theta) = \sum_{k} p(h = k)\,p(\theta|h = k) \tag{4.145}$$

where each $p(\theta|h = k)$ is conjugate, and $p(h = k)$ are called the (prior) mixing weights. One can
show (Exercise 4.6) that the posterior can also be written as a mixture of conjugate distributions as
follows:

$$p(\theta|\mathcal{D}) = \sum_{k} p(h = k|\mathcal{D})\,p(\theta|\mathcal{D}, h = k) \tag{4.146}$$

where $p(h = k|\mathcal{D})$ are the posterior mixing weights given by

$$p(h = k|\mathcal{D}) = \frac{p(h = k)\,p(\mathcal{D}|h = k)}{\sum_{k'} p(h = k')\,p(\mathcal{D}|h = k')} \tag{4.147}$$

Here the quantity $p(\mathcal{D}|h = k)$ is the marginal likelihood for mixture component k (see Section 4.6.2.10).
Returning to our example above, if we have the prior in Equation (4.144), and we observe $N_1 = 20$
heads and $N_0 = 10$ tails, then, using Equation (4.143), the posterior becomes

$$p(\theta|\mathcal{D}) = 0.346\,\mathrm{Beta}(\theta|40, 30) + 0.654\,\mathrm{Beta}(\theta|50, 20) \tag{4.148}$$

See Figure 4.13 for an illustration.

Figure 4.13: A mixture of two Beta distributions. Generated by mixbetademo.ipynb.

We can compute the posterior probability that the coin is biased towards heads as follows:

$$\Pr(\theta > 0.5|\mathcal{D}) = \sum_{k} \Pr(\theta > 0.5|\mathcal{D}, h = k)\,p(h = k|\mathcal{D}) = 0.9604 \tag{4.149}$$

If we just used a single Beta(20,20) prior, we would get a slightly smaller value of $\Pr(\theta > 0.5|\mathcal{D}) =$
0.8858. So if we were “suspicious” initially that the casino might be using a biased coin, our fears
would be confirmed more quickly than if we had to be convinced starting with an open mind. 
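
The numbers in this example can be checked with a short script. This is a sketch using only the standard library, not the notebook named in the figure caption: the tail probability is approximated with a midpoint rule (an assumption made to avoid depending on a statistics package), and all helper names are illustrative.

```python
import math

def log_beta(a, b):
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def beta_pdf(t, a, b):
    return math.exp((a - 1) * math.log(t) + (b - 1) * math.log(1 - t) - log_beta(a, b))

def beta_tail_prob(a, b, threshold=0.5, num_points=20000):
    """Pr(theta > threshold) under Beta(a, b), by the midpoint rule."""
    width = (1.0 - threshold) / num_points
    return width * sum(beta_pdf(threshold + (i + 0.5) * width, a, b)
                       for i in range(num_points))

prior_weights = [0.5, 0.5]
prior_params = [(20, 20), (30, 10)]
n1, n0 = 20, 10

# Log marginal likelihood per component; the binomial coefficient is shared
# by both components, so it cancels in the normalized weights.
log_w = [math.log(w) + log_beta(a + n1, b + n0) - log_beta(a, b)
         for w, (a, b) in zip(prior_weights, prior_params)]
shift = max(log_w)
unnorm = [math.exp(lw - shift) for lw in log_w]
post_weights = [u / sum(unnorm) for u in unnorm]

post_params = [(a + n1, b + n0) for a, b in prior_params]
prob_biased = sum(w * beta_tail_prob(a, b)
                  for w, (a, b) in zip(post_weights, post_params))
print([round(w, 3) for w in post_weights], round(prob_biased, 4))
```

Each component is updated with the usual conjugate rule, and only the mixing weights require the marginal likelihoods.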


4.6.3 The Dirichlet-multinomial model 


In this section, we generalize the results from Section 4.6.2 from binary variables (e.g., coins) to 
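K-ary variables (e.g., dice).


4.6.3.1 Likelihood


Let $Y \sim \mathrm{Cat}(\boldsymbol{\theta})$ be a discrete random variable drawn from a categorical distribution. The likelihood
has the form

$$p(\mathcal{D}|\boldsymbol{\theta}) = \prod_{n=1}^{N} \mathrm{Cat}(y_n|\boldsymbol{\theta}) = \prod_{n=1}^{N}\prod_{c=1}^{C} \theta_c^{\mathbb{I}(y_n = c)} = \prod_{c=1}^{C} \theta_c^{N_c} \tag{4.150}$$

where $N_c = \sum_{n} \mathbb{I}(y_n = c)$.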


4.6.3.2 Prior 


The conjugate prior for a categorical distribution is the Dirichlet distribution, which is a
multivariate generalization of the beta distribution. This has support over the probability simplex,
defined by

$$S_K = \left\{\boldsymbol{\theta} : 0 \le \theta_k \le 1, \sum_{k=1}^{K} \theta_k = 1\right\} \tag{4.151}$$

The pdf of the Dirichlet is defined as follows:

$$\mathrm{Dir}(\boldsymbol{\theta}|\breve{\boldsymbol{\alpha}}) \triangleq \frac{1}{B(\breve{\boldsymbol{\alpha}})} \prod_{k=1}^{K} \theta_k^{\breve{\alpha}_k - 1}\,\mathbb{I}(\boldsymbol{\theta} \in S_K) \tag{4.152}$$

where $B(\breve{\boldsymbol{\alpha}})$ is the multivariate beta function,

$$B(\breve{\boldsymbol{\alpha}}) \triangleq \frac{\prod_{k=1}^{K} \Gamma(\breve{\alpha}_k)}{\Gamma\left(\sum_{k=1}^{K} \breve{\alpha}_k\right)} \tag{4.153}$$

Figure 4.14 shows some plots of the Dirichlet when K = 3. We see that $\breve{\alpha}_0 = \sum_k \breve{\alpha}_k$ controls the
strength of the distribution (how peaked it is), and the $\breve{\alpha}_k$ control where the peak occurs. For example,
Dir(1,1,1) is a uniform distribution, Dir(2,2,2) is a broad distribution centered at (1/3, 1/3, 1/3),
and Dir(20,20,20) is a narrow distribution centered at (1/3, 1/3, 1/3). Dir(3,3,20) is an asymmetric
distribution that puts more density in one of the corners. If $\breve{\alpha}_k < 1$ for all k, we get “spikes” at
the corners of the simplex. Samples from the distribution when $\breve{\alpha}_k < 1$ will be sparse, as shown in
Figure 4.15.


4.6.3.3 Posterior 


We can combine the multinomial likelihood and Dirichlet prior to compute the posterior, as follows: 


\begin{align}
p(\boldsymbol{\theta}|\mathcal{D}) &\propto p(\mathcal{D}|\boldsymbol{\theta})\,\mathrm{Dir}(\boldsymbol{\theta}|\breve{\boldsymbol{\alpha}}) \tag{4.154} \\
&\propto \left[\prod_{k} \theta_k^{N_k}\right]\left[\prod_{k} \theta_k^{\breve{\alpha}_k - 1}\right] \tag{4.155} \\
&\propto \mathrm{Dir}(\boldsymbol{\theta}|\breve{\alpha}_1 + N_1, \ldots, \breve{\alpha}_K + N_K) \tag{4.156} \\
&= \mathrm{Dir}(\boldsymbol{\theta}|\widehat{\boldsymbol{\alpha}}) \tag{4.157}
\end{align}

where $\widehat{\alpha}_k = \breve{\alpha}_k + N_k$ are the parameters of the posterior. So we see that the posterior can be computed
by adding the empirical counts to the prior counts.

Figure 4.14: (a) The Dirichlet distribution when K = 3 defines a distribution over the simplex, which can be
represented by the triangular surface. Points on this surface satisfy $0 \le \theta_k \le 1$ and $\sum_{k=1}^{3} \theta_k = 1$. Generated
by dirichlet_3d_triangle_plot.ipynb. (b) Plot of the Dirichlet density for $\breve{\boldsymbol{\alpha}} = (20, 20, 20)$. (c) Plot of the
Dirichlet density for $\breve{\boldsymbol{\alpha}} = (3, 3, 20)$. (d) Plot of the Dirichlet density for $\breve{\boldsymbol{\alpha}} = (0.1, 0.1, 0.1)$. Generated by
dirichlet_3d_spiky_plot.ipynb.

The posterior mean is given by

$$\bar{\theta}_k = \frac{\widehat{\alpha}_k}{\sum_{k'=1}^{K} \widehat{\alpha}_{k'}} \tag{4.158}$$

The posterior mode, which corresponds to the MAP estimate, is given by

$$\hat{\theta}_k = \frac{\widehat{\alpha}_k - 1}{\sum_{k'=1}^{K} (\widehat{\alpha}_{k'} - 1)} \tag{4.159}$$

Figure 4.15: Samples from a 5-dimensional symmetric Dirichlet distribution for different parameter values.
(a) $\breve{\boldsymbol{\alpha}} = (0.1, \ldots, 0.1)$. This results in very sparse distributions, with many 0s. (b) $\breve{\boldsymbol{\alpha}} = (1, \ldots, 1)$. This results
in more uniform (and dense) distributions. Generated by dirichlet_samples_plot.ipynb.

If we use $\breve{\alpha}_k = 1$, corresponding to a uniform prior, the MAP becomes the MLE:

$$\hat{\theta}_k = N_k / N \tag{4.160}$$


(See Section 4.2.4 for a more direct derivation of this result.) 


4.6.3.4 Posterior predictive 


The posterior predictive distribution is given by 


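\begin{align}
p(y = k|\mathcal{D}) &= \int p(y = k|\boldsymbol{\theta})\,p(\boldsymbol{\theta}|\mathcal{D})\,d\boldsymbol{\theta} \tag{4.161} \\
&= \int \theta_k\,p(\theta_k|\mathcal{D})\,d\theta_k = \mathbb{E}[\theta_k|\mathcal{D}] = \frac{\widehat{\alpha}_k}{\sum_{k'} \widehat{\alpha}_{k'}} \tag{4.162}
\end{align}

In other words, the posterior predictive distribution is given by

$$p(y|\mathcal{D}) = \mathrm{Cat}(y|\bar{\boldsymbol{\theta}}) \tag{4.163}$$

where $\bar{\boldsymbol{\theta}} \triangleq \mathbb{E}[\boldsymbol{\theta}|\mathcal{D}]$ are the posterior mean parameters. If instead we plug-in the MAP estimate, we
will suffer from the zero-count problem. The only way to get the same effect as add-one smoothing is
to use a MAP estimate with $\breve{\alpha}_c = 2$.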
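
The relationship between the posterior mean, the MAP estimate, and add-one smoothing can be verified on a small example. The counts below (K = 3 outcomes, one of which was never observed) and the function names are assumptions made for illustration.

```python
def dirichlet_posterior(alpha_prior, counts):
    """Conjugate update: add the empirical counts to the prior counts."""
    return [a + n for a, n in zip(alpha_prior, counts)]

def posterior_mean(alpha_post):
    total = sum(alpha_post)
    return [a / total for a in alpha_post]

def posterior_mode(alpha_post):
    """MAP estimate, as in Equation (4.159); assumes every entry is at least 1."""
    total = sum(a - 1 for a in alpha_post)
    return [(a - 1) / total for a in alpha_post]

# Assumed data: outcome 2 of 3 never occurs in N = 5 trials.
counts = [3, 0, 2]

uniform_post = dirichlet_posterior([1, 1, 1], counts)
predictive = posterior_mean(uniform_post)      # posterior predictive probabilities
map_uniform = posterior_mode(uniform_post)     # equals the MLE, with a zero entry
map_smoothed = posterior_mode(dirichlet_posterior([2, 2, 2], counts))
print(predictive, map_uniform, map_smoothed)
```

The posterior predictive under the uniform prior coincides with the MAP estimate under a prior with all parameters set to 2, and both assign nonzero probability to the unseen outcome.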

Equation (4.162) gives the probability of a single future event, conditioned on past observations
$\boldsymbol{y} = (y_1, \ldots, y_N)$. In some cases, we want to know the probability of observing a batch of future data,
say $\tilde{\boldsymbol{y}} = (\tilde{y}_1, \ldots, \tilde{y}_M)$. We can compute this as follows:

$$p(\tilde{\boldsymbol{y}}|\boldsymbol{y}) = \frac{p(\tilde{\boldsymbol{y}}, \boldsymbol{y})}{p(\boldsymbol{y})} \tag{4.164}$$

The denominator is the marginal likelihood of the training data, and the numerator is the marginal
likelihood of the training and future test data. We discuss how to compute such marginal likelihoods
in Section 4.6.3.5.


4.6.3.5 Marginal likelihood


By the same reasoning as in Section 4.6.2.10, one can show that the marginal likelihood for the
Dirichlet-categorical model is given by

$$p(\mathcal{D}) = \frac{B(\boldsymbol{N} + \boldsymbol{\alpha})}{B(\boldsymbol{\alpha})} \tag{4.165}$$

where

$$B(\boldsymbol{\alpha}) = \frac{\prod_{k=1}^{K} \Gamma(\alpha_k)}{\Gamma\left(\sum_{k} \alpha_k\right)} \tag{4.166}$$

Hence we can rewrite the above result in the following form, which is what is usually presented in
the literature:

$$p(\mathcal{D}) = \frac{\Gamma\left(\sum_{k} \alpha_k\right)}{\Gamma\left(N + \sum_{k} \alpha_k\right)} \prod_{k} \frac{\Gamma(N_k + \alpha_k)}{\Gamma(\alpha_k)} \tag{4.167}$$


4.6.4 The Gaussian-Gaussian model 


In this section, we derive the posterior for the parameters of a Gaussian distribution. For simplicity, 
we assume the variance is known. (The general case is discussed in the sequel to this book, [Mur23], 
as well as other standard references on Bayesian statistics.) 


4.6.4.1 Univariate case 


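If $\sigma^2$ is a known constant, the likelihood for $\mu$ has the form

$$p(\mathcal{D}|\mu) \propto \exp\left(-\frac{1}{2\sigma^2} \sum_{n=1}^{N} (y_n - \mu)^2\right) \tag{4.168}$$

One can show that the conjugate prior is another Gaussian, $\mathcal{N}(\mu|\breve{m}, \breve{\tau}^2)$. Applying Bayes’ rule for
Gaussians, as in Section 4.6.4.1, we find that the corresponding posterior is given by

\begin{align}
p(\mu|\mathcal{D}, \sigma^2) &= \mathcal{N}(\mu|\widehat{m}, \widehat{\tau}^2) \tag{4.169} \\
\widehat{\tau}^2 &= \frac{1}{\frac{N}{\sigma^2} + \frac{1}{\breve{\tau}^2}} = \frac{\sigma^2 \breve{\tau}^2}{N \breve{\tau}^2 + \sigma^2} \tag{4.170} \\
\widehat{m} &= \widehat{\tau}^2 \left(\frac{\breve{m}}{\breve{\tau}^2} + \frac{N \bar{y}}{\sigma^2}\right) = \frac{\sigma^2}{N \breve{\tau}^2 + \sigma^2}\,\breve{m} + \frac{N \breve{\tau}^2}{N \breve{\tau}^2 + \sigma^2}\,\bar{y} \tag{4.171}
\end{align}

where $\bar{y} \triangleq \frac{1}{N} \sum_{n=1}^{N} y_n$ is the empirical mean.
This result is easier to understand if we work in terms of the precision parameters, which are
just inverse variances. Specifically, let $\kappa = 1/\sigma^2$ be the observation precision, and $\breve{\lambda} = 1/\breve{\tau}^2$ be the
precision of the prior. We can then rewrite the posterior as follows:

\begin{align}
p(\mu|\mathcal{D}, \kappa) &= \mathcal{N}(\mu|\widehat{m}, \widehat{\lambda}^{-1}) \tag{4.172} \\
\widehat{\lambda} &= \breve{\lambda} + N\kappa \tag{4.173} \\
\widehat{m} &= \frac{N\kappa\bar{y} + \breve{\lambda}\breve{m}}{\widehat{\lambda}} = \frac{N\kappa}{N\kappa + \breve{\lambda}}\,\bar{y} + \frac{\breve{\lambda}}{N\kappa + \breve{\lambda}}\,\breve{m} \tag{4.174}
\end{align}

These equations are quite intuitive: the posterior precision $\widehat{\lambda}$ is the prior precision $\breve{\lambda}$ plus N units of
measurement precision $\kappa$. Also, the posterior mean $\widehat{m}$ is a convex combination of the empirical mean
$\bar{y}$ and the prior mean $\breve{m}$. This makes it clear that the posterior mean is a compromise between the
empirical mean and the prior. If the prior is weak relative to the signal strength ($\breve{\lambda}$ is small relative
to $\kappa$), we put more weight on the empirical mean. If the prior is strong relative to the signal strength
($\breve{\lambda}$ is large relative to $\kappa$), we put more weight on the prior. This is illustrated in Figure 4.16. Note
also that the posterior mean is written in terms of $N\kappa\bar{y}$, so having N measurements each of precision
$\kappa$ is like having one measurement with value $\bar{y}$ and precision $N\kappa$.

Figure 4.16: Inferring the mean of a univariate Gaussian with known $\sigma^2$ given observation y = 3. (a) Using
strong prior, $p(\mu) = \mathcal{N}(\mu|0, 1)$. (b) Using weak prior, $p(\mu) = \mathcal{N}(\mu|0, 5)$. Generated by gauss_infer_1d.ipynb.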
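
The precision-weighted update in Equations (4.173) and (4.174) can be written as a small function. In the sketch below the function name is hypothetical, and the observation variance of 1 is an assumption: the Figure 4.16 caption gives the observation y = 3 and the two priors, but not the value of the known variance.

```python
def gaussian_posterior(ys, sigma2, prior_mean, prior_var):
    """Posterior mean and variance of mu for a Gaussian with known variance."""
    n = len(ys)
    y_bar = sum(ys) / n
    kappa = 1.0 / sigma2                 # observation precision
    prior_prec = 1.0 / prior_var         # prior precision
    post_prec = prior_prec + n * kappa   # Equation (4.173)
    post_mean = (n * kappa * y_bar + prior_prec * prior_mean) / post_prec
    return post_mean, 1.0 / post_prec

# Single observation y = 3, assumed observation variance sigma^2 = 1.
strong_mean, strong_var = gaussian_posterior([3.0], 1.0, 0.0, 1.0)
weak_mean, weak_var = gaussian_posterior([3.0], 1.0, 0.0, 5.0)
print(strong_mean, strong_var, weak_mean, weak_var)
```

Under these assumptions the strong prior pulls the posterior mean halfway to the prior mean (1.5), while the weak prior leaves it much closer to the observation (2.5).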


Posterior after seeing N = 1 examples 


To gain further insight into these equations, consider the posterior after seeing a single data point y 
(so N = 1). Then the posterior mean can be written in the following equivalent ways: 


\begin{align}
\widehat{m} &= \frac{\breve{\lambda}}{\widehat{\lambda}}\,\breve{m} + \frac{\kappa}{\widehat{\lambda}}\,y \tag{4.175} \\
&= \breve{m} + \frac{\kappa}{\widehat{\lambda}}(y - \breve{m}) \tag{4.176} \\
&= y - \frac{\breve{\lambda}}{\widehat{\lambda}}(y - \breve{m}) \tag{4.177}
\end{align}


The first equation is a convex combination of the prior mean and the data. The second equation
is the prior mean adjusted towards the data y. The third equation is the data adjusted towards
the prior mean; this is called a shrinkage estimate. This is easier to see if we define the weight
$w = \breve{\lambda}/\widehat{\lambda}$, which is the ratio of the prior to posterior precision. Then we have

$$\widehat{m} = y - w(y - \breve{m}) = (1 - w)\,y + w\,\breve{m} \tag{4.178}$$

Note that, for a Gaussian, the posterior mean and posterior mode are the same. Thus we can use
the above equations to perform MAP estimation. See Exercise 4.2 for a simple example.


Posterior variance


In addition to the posterior mean or mode of $\mu$, we might be interested in the posterior variance,
which gives us a measure of confidence in our estimate. The square root of this is called the standard
error of the mean:

$$\mathrm{se}(\mu) \triangleq \sqrt{\mathbb{V}[\mu|\mathcal{D}]} \tag{4.179}$$

Suppose we use an uninformative prior for $\mu$ by setting $\breve{\lambda} = 0$ (see Section 4.6.5.1). In this case, the
posterior mean is equal to the MLE, $\widehat{m} = \bar{y}$. Suppose, in addition, that we approximate $\sigma^2$ by the
sample variance

$$s^2 \triangleq \frac{1}{N} \sum_{n=1}^{N} (y_n - \bar{y})^2 \tag{4.180}$$

Hence $\widehat{\lambda} = N\hat{\kappa} = N/s^2$, so the SEM becomes

$$\mathrm{se}(\mu) = \sqrt{\mathbb{V}[\mu|\mathcal{D}]} = \frac{1}{\sqrt{\widehat{\lambda}}} = \frac{s}{\sqrt{N}} \tag{4.181}$$

Thus we see that the uncertainty in $\mu$ is reduced at a rate of $1/\sqrt{N}$.
In addition, we can use the fact that 95% of a Gaussian distribution is contained within 2 standard
deviations of the mean to approximate the 95% credible interval for $\mu$ using

$$I_{.95}(\mu|\mathcal{D}) = \bar{y} \pm 2\frac{s}{\sqrt{N}} \tag{4.182}$$
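
A minimal sketch of Equations (4.180) to (4.182) follows. The function name and the eight data values are made up for illustration; the sample variance uses the 1/N normalization given above.

```python
import math

def mean_credible_interval(ys):
    """Return the empirical mean, the SEM, and the approximate 95% interval."""
    n = len(ys)
    y_bar = sum(ys) / n
    s2 = sum((y - y_bar) ** 2 for y in ys) / n   # sample variance, Equation (4.180)
    sem = math.sqrt(s2 / n)                      # s / sqrt(N), Equation (4.181)
    return y_bar, sem, (y_bar - 2 * sem, y_bar + 2 * sem)

# Assumed toy measurements.
ys = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
y_bar, sem, interval = mean_credible_interval(ys)
print(y_bar, round(sem, 4), tuple(round(v, 3) for v in interval))
```

For these values the sample standard deviation is 2, so the SEM is 2 divided by the square root of 8, and quadrupling the number of measurements would halve it.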


4.6.4.2 Multivariate case 


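For D-dimensional data, the likelihood has the following form (where we drop terms that are
independent of $\boldsymbol{\mu}$):

\begin{align}
p(\mathcal{D}|\boldsymbol{\mu}) &= \prod_{n=1}^{N} \mathcal{N}(\boldsymbol{y}_n|\boldsymbol{\mu}, \boldsymbol{\Sigma}) \tag{4.183} \\
&\propto \exp\left[-\frac{1}{2} \sum_{n=1}^{N} (\boldsymbol{y}_n - \boldsymbol{\mu})^{\mathsf{T}} \boldsymbol{\Sigma}^{-1} (\boldsymbol{y}_n - \boldsymbol{\mu})\right] \tag{4.184} \\
&\propto \mathcal{N}\left(\bar{\boldsymbol{y}}\,\middle|\,\boldsymbol{\mu}, \tfrac{1}{N}\boldsymbol{\Sigma}\right) \tag{4.185}
\end{align}

where $\bar{\boldsymbol{y}} = \frac{1}{N} \sum_{n=1}^{N} \boldsymbol{y}_n$. (The proof of the last equation is given right after Equation (3.65).) Thus
we replace the set of observations with their mean, and scale down the covariance by a factor of N.

Figure 4.17: Illustration of Bayesian inference for the mean of a 2d Gaussian. (a) The data is generated
from $\boldsymbol{y}_n \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$, where $\boldsymbol{\mu} = [0.5, 0.5]^{\mathsf{T}}$ and $\boldsymbol{\Sigma} = 0.1[2, 1; 1, 1]$. (b) The prior is $p(\boldsymbol{\mu}) = \mathcal{N}(\boldsymbol{\mu}|\boldsymbol{0}, 0.1\,\mathbf{I}_2)$. (c)
We show the posterior after 10 data points have been observed. Generated by gauss_infer_2d.ipynb.

For simplicity, we will use a conjugate prior, which in this case is a Gaussian, namely

$$p(\boldsymbol{\mu}) = \mathcal{N}(\boldsymbol{\mu}|\breve{\boldsymbol{m}}, \breve{\mathbf{V}}) \tag{4.186}$$

We can derive a Gaussian posterior for $\boldsymbol{\mu}$ based on the results in Section 3.3.1. We get

\begin{align}
p(\boldsymbol{\mu}|\mathcal{D}, \boldsymbol{\Sigma}) &= \mathcal{N}(\boldsymbol{\mu}|\widehat{\boldsymbol{m}}, \widehat{\mathbf{V}}) \tag{4.187} \\
\widehat{\mathbf{V}}^{-1} &= \breve{\mathbf{V}}^{-1} + N\boldsymbol{\Sigma}^{-1} \tag{4.188} \\
\widehat{\boldsymbol{m}} &= \widehat{\mathbf{V}}\left(\boldsymbol{\Sigma}^{-1}(N\bar{\boldsymbol{y}}) + \breve{\mathbf{V}}^{-1}\breve{\boldsymbol{m}}\right) \tag{4.189}
\end{align}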

Figure 4.17 gives a 2d example of these results. 


4.6.5 Beyond conjugate priors 


We have seen various examples of conjugate priors, all of which have come from the exponential 
family (see Section 3.4). These priors have the advantage of being easy to interpret (in terms of 
sufficient statistics from a virtual prior dataset), and easy to compute with. However, for most 
models, there is no prior in the exponential family that is conjugate to the likelihood. Furthermore, 
even where there is a conjugate prior, the assumption of conjugacy may be too limiting. Therefore in 
the sections below, we briefly discuss various other kinds of priors. 


4.6.5.1 Noninformative priors 


When we have little or no domain specific knowledge, it is desirable to use an uninformative,
noninformative or objective prior, to “let the data speak for itself”. For example, if we want to
infer a real valued quantity, such as a location parameter $\mu \in \mathbb{R}$, we can use a flat prior $p(\mu) \propto 1$.
This can be viewed as an “infinitely wide” Gaussian.
Unfortunately, there is no unique way to define uninformative priors, and they all encode some
kind of knowledge. It is therefore better to use the term diffuse prior, minimally informative
prior or default prior. See the sequel to this book, [Mur23], for more details.


4.6.5.2 Hierarchical priors 


Bayesian models require specifying a prior $p(\boldsymbol{\theta})$ for the parameters. The parameters of the prior are
called hyperparameters, and will be denoted by $\boldsymbol{\xi}$. If these are unknown, we can put a prior on
them; this defines a hierarchical Bayesian model, or multi-level model, which we can visualize
like this: $\boldsymbol{\xi} \rightarrow \boldsymbol{\theta} \rightarrow \mathcal{D}$. We assume the prior on the hyper-parameters is fixed (e.g., we may use some
kind of minimally informative prior), so the joint distribution has the form

$$p(\boldsymbol{\xi}, \boldsymbol{\theta}, \mathcal{D}) = p(\boldsymbol{\xi})\,p(\boldsymbol{\theta}|\boldsymbol{\xi})\,p(\mathcal{D}|\boldsymbol{\theta}) \tag{4.190}$$

The hope is that we can learn the hyperparameters by treating the parameters themselves as
datapoints. This is useful when we have multiple related parameters that need to be estimated (e.g.,
from different subpopulations, or multiple tasks); this provides a learning signal to the top level of the
model. See the sequel to this book, [Mur23], for details.


4.6.5.3 Empirical priors 


In Section 4.6.5.2, we discussed hierarchical Bayes as a way to infer parameters from data. Unfortu- 
nately, posterior inference in such models can be computationally challenging. In this section, we 
discuss a computationally convenient approximation, in which we first compute a point estimate of 
the hyperparameters, È, and then compute the conditional posterior, p(Olé , D), rather than the joint 
posterior, p(@,€|D). 

To estimate the hyper-parameters, we can maximize the marginal likelihood: 


Emi (D) = argmax p(DIg) = argmox J o(ojo)v(ole)a0 (4.191) 


This technique is known as type II maximum likelihood, since we are optimizing the hyperparam- 
eters, rather than the parameters. Once we have estimated Ê, we compute the posterior D(Olé, D) 
in the usual way. 

Since we are estimating the prior parameters from data, this approach is empirical Bayes (EB) 
[CL96]. This violates the principle that the prior should be chosen independently of the data. 
However, we can view it as a computationally cheap approximation to inference in the full hierarchical 
Bayesian model, just as we viewed MAP estimation as an approximation to inference in the one level 
model 0 > D. In fact, we can construct a hierarchy in which the more integrals one performs, the 
“more Bayesian” one becomes, as shown below. 


Method Definition 

Maximum likelihood 6 = argmaxg p(D|@) 

MAP estimation 6(€) = argmaxg p(D|@)p(O\g) 
ML-II (Empirical Bayes) € = argmaxg | p(D|0)p(6|€)d0 
MAP-II € = argmax, f p(D|0)p(9|€)p(€)de 
Full Bayes p(0, €|D) x p(D|6)p(9|€)p() 
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Figure 4.18: (a) Central interval and (b) HPD region for a Beta(3,9) posterior. The CI is (0.06, 0.52) and 
the HPD is (0.04, 0.48). Adapted from Figure 3.6 of [Hof09]. Generated by betaHPD.ipynb. 


Note that ML-II is less likely to overfit than “regular” maximum likelihood, because there are 
typically fewer hyperparameters $\boldsymbol{\xi}$ than there are parameters $\boldsymbol{\theta}$. See the sequel to this book, [Mur23], 
for details. 
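
To make the ML-II recipe concrete, here is a minimal sketch for one model where the marginal
likelihood has a closed form. The model is our own illustrative assumption, not one used in the text:
group-level parameters $\theta_j \sim \mathcal{N}(\mu, \tau^2)$ and observations
$y_j \sim \mathcal{N}(\theta_j, \sigma^2)$ with $\sigma^2$ known, so that the hyperparameters are
$\boldsymbol{\xi} = (\mu, \tau^2)$. Integrating out $\theta_j$ gives $y_j \sim \mathcal{N}(\mu, \sigma^2 + \tau^2)$.
The function name and the data values are hypothetical.

```python
from statistics import fmean, pvariance

def empirical_bayes_gaussian(y, sigma2):
    """ML-II for theta_j ~ N(mu, tau2), y_j ~ N(theta_j, sigma2), sigma2 > 0 known.

    The marginal likelihood is prod_j N(y_j | mu, sigma2 + tau2), which is
    maximized by the sample mean and the (divide-by-N) sample variance.
    """
    mu_hat = fmean(y)
    total_var = pvariance(y, mu_hat)          # MLE of sigma2 + tau2
    tau2_hat = max(0.0, total_var - sigma2)   # a variance cannot be negative
    # Conditional posterior p(theta_j | xi_hat, y_j) is Gaussian; its mean
    # shrinks each observation towards the estimated prior mean.
    shrink = tau2_hat / (tau2_hat + sigma2)
    post_means = [mu_hat + shrink * (yj - mu_hat) for yj in y]
    post_var = shrink * sigma2                # tau2 * sigma2 / (tau2 + sigma2)
    return mu_hat, tau2_hat, post_means, post_var

y = [2.0, 4.0, 6.0, 8.0]                      # hypothetical group observations
mu_hat, tau2_hat, post_means, post_var = empirical_bayes_gaussian(y, sigma2=1.0)
print(mu_hat, tau2_hat, post_means, post_var)
```

Note that the prior parameters are estimated from the same data that are then used to compute the
conditional posterior, which is exactly the "double use" of data that distinguishes empirical Bayes
from a fully hierarchical treatment.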


4.6.6 Credible intervals 


A posterior distribution is (usually) a high dimensional object that is hard to visualize and work 
with. A common way to summarize such a distribution is to compute a point estimate, such as the 
posterior mean or mode, and then to compute a credible interval, which quantifies the uncertainty 
associated with that estimate. (A credible interval is not the same as a confidence interval, which is 
a concept from frequentist statistics which we discuss in Section 4.7.4.) 

More precisely, we define a $100(1-\alpha)\%$ credible interval to be a (contiguous) region $C = (\ell, u)$ 
(standing for lower and upper) which contains $1-\alpha$ of the posterior probability mass, i.e., 

$$C_{\alpha}(\mathcal{D}) = (\ell, u) : P(\ell \le \theta \le u | \mathcal{D}) = 1 - \alpha \tag{4.192}$$

There may be many intervals that satisfy Equation (4.192), so we usually choose one such that there 
is $\alpha/2$ mass in each tail; this is called a central interval. If the posterior has a known functional 
form, we can compute the posterior central interval using $\ell = F^{-1}(\alpha/2)$ and $u = F^{-1}(1-\alpha/2)$, where 
$F$ is the cdf of the posterior, and $F^{-1}$ is the inverse cdf. For example, if the posterior is Gaussian, 
$p(\theta|\mathcal{D}) = \mathcal{N}(0, 1)$, and $\alpha = 0.05$, then we have $\ell = \Phi^{-1}(\alpha/2) = -1.96$, and $u = \Phi^{-1}(1-\alpha/2) = 1.96$, 
where $\Phi$ denotes the cdf of the Gaussian. This is illustrated in Figure 2.2b. This justifies the common 
practice of quoting a credible interval in the form of $\mu \pm 2\sigma$, where $\mu$ represents the posterior mean, 
$\sigma$ represents the posterior standard deviation, and 2 is a good approximation to 1.96. 

In general, it is often hard to compute the inverse cdf of the posterior. In this case, a simple 
alternative is to draw samples from the posterior, and then to use a Monte Carlo approximation to the 
posterior quantiles: to estimate the $\alpha$ quantile, we simply sort the $S$ samples, and find the one that 
occurs at location $\alpha S$ along the sorted list. As $S \to \infty$, this converges to the true quantile. See 
beta_credible_int_demo.ipynb for a demo of this. 
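
The two ways of computing a central interval can be sketched as follows. This is our own minimal
illustration using only the Python standard library, not the code from the notebook; the helper name
`central_interval_from_samples` and the sample size are assumptions made for the example.

```python
import random
from statistics import NormalDist

alpha = 0.05

# Case 1: the inverse cdf is known. For a standard Gaussian posterior N(0, 1),
# the central interval is (F^{-1}(alpha/2), F^{-1}(1 - alpha/2)).
post = NormalDist(mu=0.0, sigma=1.0)
lower = post.inv_cdf(alpha / 2)
upper = post.inv_cdf(1 - alpha / 2)

# Case 2: only posterior samples are available. Sort them and read off the
# entries at (approximately) positions (alpha/2)*S and (1 - alpha/2)*S.
def central_interval_from_samples(samples, alpha=0.05):
    xs = sorted(samples)
    S = len(xs)
    lo_idx = int((alpha / 2) * S)
    hi_idx = min(S - 1, int((1 - alpha / 2) * S))
    return xs[lo_idx], xs[hi_idx]

rng = random.Random(0)
samples = [rng.betavariate(3, 9) for _ in range(20000)]   # draws from Beta(3, 9)
mc_lower, mc_upper = central_interval_from_samples(samples, alpha)

print(f"Gaussian: ({lower:.2f}, {upper:.2f})")
print(f"Beta(3,9), Monte Carlo: ({mc_lower:.2f}, {mc_upper:.2f})")
```

For the Beta(3,9) case, the Monte Carlo estimate should be close to the central interval quoted in
Figure 4.18, up to sampling noise.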

Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 

Figure 4.19: (a) Central interval and (b) HPD region for a hypothetical multimodal posterior. Adapted from 
Figure 2.2 of [Gel+04]. Generated by postDensityIntervals.ipynb. 

A problem with central intervals is that there might be points outside the central interval which 
have higher probability than points that are inside, as illustrated in Figure 4.18(a). This motivates 
an alternative quantity known as the highest posterior density or HPD region, which is the set 
of points which have a posterior density above some threshold. More precisely we find the threshold $p^*$ on 
the pdf such that 


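$$1 - \alpha = \int_{\theta : p(\theta|\mathcal{D}) > p^*} p(\theta|\mathcal{D}) \, d\theta \tag{4.193}$$

and then define the HPD as 

$$C_{\alpha}(\mathcal{D}) = \{\theta : p(\theta|\mathcal{D}) \ge p^*\} \tag{4.194}$$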


In 1d, the HPD region is sometimes called a highest density interval or HDI. For example, 
Figure 4.18(b) shows the 95% HDI of a Beta(3,9) distribution, which is (0.04,0.48). We see that 
this is narrower than the central interval, even though it still contains 95% of the mass; furthermore, 
every point inside of it has higher density than every point outside of it. 

For a unimodal distribution, the HDI will be the narrowest interval around the mode containing 
95% of the mass. To see this, imagine “water filling” in reverse, where we lower the level until 95% of 
the mass is revealed, and only 5% is submerged. This gives a simple algorithm for computing HDIs 
in the 1d case: simply search over points such that the interval contains 95% of the mass and has 
minimal width. This can be done by 1d numerical optimization if we know the inverse CDF of the 
distribution, or by search over the sorted data points if we have a bag of samples (see betaHPD.ipynb 
for some code). 
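
The sample-based search described above can be sketched as follows. This is a minimal standard-library
illustration written for this passage, not the code in betaHPD.ipynb; the function name
`hdi_from_samples` is our own.

```python
import math
import random

def hdi_from_samples(samples, mass=0.95):
    """Narrowest interval containing `mass` of the samples (1d, unimodal case)."""
    xs = sorted(samples)
    S = len(xs)
    k = math.ceil(mass * S)   # number of samples each candidate interval must hold
    # Slide a window over k consecutive sorted samples; keep the narrowest one.
    best = min(range(S - k + 1), key=lambda i: xs[i + k - 1] - xs[i])
    return xs[best], xs[best + k - 1]

rng = random.Random(0)
samples = sorted(rng.betavariate(3, 9) for _ in range(20000))
S = len(samples)

hdi_lo, hdi_hi = hdi_from_samples(samples, mass=0.95)
# Central interval from the same samples, for comparison.
ci_lo, ci_hi = samples[int(0.025 * S)], samples[int(0.975 * S)]

print(f"HDI:     ({hdi_lo:.2f}, {hdi_hi:.2f}), width {hdi_hi - hdi_lo:.3f}")
print(f"Central: ({ci_lo:.2f}, {ci_hi:.2f}), width {ci_hi - ci_lo:.3f}")
```

Because the posterior here is skewed, the HDI is shifted towards the mode and is narrower than the
central interval. For a multimodal posterior this window search would return a single interval, so it
would not recover a disconnected HPD region.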

If the posterior is multimodal, the HDI may not even be a connected region: see Figure 4.19(b) for 
an example. However, summarizing multimodal posteriors is always difficult. 


4.6.7 Bayesian machine learning 


So far, we have focused on unconditional models of the form $p(\boldsymbol{y}|\boldsymbol{\theta})$. In supervised machine learning, 
we use conditional models of the form $p(\boldsymbol{y}|\boldsymbol{x}, \boldsymbol{\theta})$. The posterior over the parameters is now $p(\boldsymbol{\theta}|\mathcal{D})$, 
where $\mathcal{D} = \{(\boldsymbol{x}_n, \boldsymbol{y}_n) : n = 1:N\}$. Computing this posterior can be done using the principles we 
have already discussed. This approach is called Bayesian machine learning, since we are “being 
Bayesian” about the model parameters. 


4.6.7.1 Plugin approximation 


Once we have computed the posterior over the parameters, we can compute the posterior predictive 
distribution over outputs given inputs by marginalizing out the unknown parameters: 


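$$p(\boldsymbol{y}|\boldsymbol{x}, \mathcal{D}) = \int p(\boldsymbol{y}|\boldsymbol{x}, \boldsymbol{\theta}) p(\boldsymbol{\theta}|\mathcal{D}) \, d\boldsymbol{\theta} \tag{4.195}$$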


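Of course, computing this integral is often intractable. A very simple approximation is to assume 
there is just a single best model, $\hat{\boldsymbol{\theta}}$, such as the MLE. This is equivalent to approximating the 
posterior as an infinitely narrow, but infinitely tall, “spike” at the chosen value. We can write this as 
follows: 

$$p(\boldsymbol{\theta}|\mathcal{D}) = \delta(\boldsymbol{\theta} - \hat{\boldsymbol{\theta}}) \tag{4.196}$$

where $\delta$ is the Dirac delta function (see Section 2.6.5). If we use this approximation, then the 
predictive distribution can be obtained by simply “plugging in” the point estimate into the likelihood: 

$$p(\boldsymbol{y}|\boldsymbol{x}, \mathcal{D}) = \int p(\boldsymbol{y}|\boldsymbol{x}, \boldsymbol{\theta}) p(\boldsymbol{\theta}|\mathcal{D}) \, d\boldsymbol{\theta} \approx \int p(\boldsymbol{y}|\boldsymbol{x}, \boldsymbol{\theta}) \delta(\boldsymbol{\theta} - \hat{\boldsymbol{\theta}}) \, d\boldsymbol{\theta} = p(\boldsymbol{y}|\boldsymbol{x}, \hat{\boldsymbol{\theta}}) \tag{4.197}$$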


This follows from the sifting property of delta functions (Equation (2.129)). 

The approach in Equation (4.197) is called a plug-in approximation. This approach is equivalent 
to the standard approach used in most of machine learning, in which we first fit the model (i.e., 
compute a point estimate $\hat{\boldsymbol{\theta}}$) and then use it to make predictions. However, the standard (plug-in) 
approach can suffer from overfitting and overconfidence, as we discussed in Section 1.2.3. The 
fully Bayesian approach avoids this by marginalizing out the parameters, but can be expensive. 
Fortunately, even simple approximations, in which we average over a few plausible parameter values, 
can improve performance. We give some examples of this below. 


4.6.7.2 Example: scalar input, binary output 


Suppose we want to perform binary classification, so $y \in \{0, 1\}$. We will use a model of the form 

$$p(y|\boldsymbol{x}; \boldsymbol{\theta}) = \text{Ber}(y|\sigma(\boldsymbol{w}^{\mathsf{T}} \boldsymbol{x} + b)) \tag{4.198}$$

where 

$$\sigma(a) \triangleq \frac{e^{a}}{1 + e^{a}} \tag{4.199}$$

is the sigmoid or logistic function which maps $\mathbb{R} \to (0, 1)$, and $\text{Ber}(y|\mu)$ is the Bernoulli distribution 
with mean $\mu$ (see Section 2.4 for details). In other words, 

$$p(y = 1|\boldsymbol{x}; \boldsymbol{\theta}) = \sigma(\boldsymbol{w}^{\mathsf{T}} \boldsymbol{x} + b) = \frac{1}{1 + e^{-(\boldsymbol{w}^{\mathsf{T}} \boldsymbol{x} + b)}} \tag{4.200}$$

This model is called logistic regression. (We discuss this in more detail in Chapter 10.) 

Figure 4.20: (a) Logistic regression for classifying if an Iris flower is Versicolor (y = 1) or Setosa (y = 0) 
using a single input feature x corresponding to sepal length. Labeled points have been (vertically) jittered 
to avoid overlapping too much. Vertical line is the decision boundary. Generated by logreg_iris_1d.ipynb. 
(b) Same as (a) but showing posterior distribution. Adapted from Figure 4.4 of [Mar18]. Generated by 
logreg_iris_bayes_1d_pymc3.ipynb. 


Let us apply this model to the task of determining if an iris flower is of type Setosa or Versicolor, 
$y_n \in \{0, 1\}$, given information about the sepal length, $x_n$. (See Section 1.2.1.1 for a description of 
the iris dataset.) 

We first fit a 1d logistic regression model of the following form 

$$p(y = 1|x; \boldsymbol{\theta}) = \sigma(b + wx) \tag{4.201}$$

to the dataset $\mathcal{D} = \{(x_n, y_n)\}$ using maximum likelihood estimation. (See Section 10.2.3 for details 
on how to compute the MLE for this model.) Figure 4.20a shows the plugin approximation to the 
posterior predictive, $p(y = 1|x, \hat{\boldsymbol{\theta}})$, where $\hat{\boldsymbol{\theta}}$ is the MLE of the parameters. We see that we become 
more confident that the flower is of type Versicolor as the sepal length gets larger, as represented by 
the sigmoidal (S-shaped) logistic function. 

The decision boundary is defined to be the input value $x^*$ where $p(y = 1|x^*; \hat{\boldsymbol{\theta}}) = 0.5$. We can 
solve for this value as follows: 

$$\sigma(b + wx^*) = \frac{1}{1 + e^{-(b + wx^*)}} = \frac{1}{2} \tag{4.202}$$
$$b + wx^* = 0 \tag{4.203}$$
$$x^* = -\frac{b}{w} \tag{4.204}$$

From Figure 4.20a, we see that $x^* \approx 5.5$ cm. 

However, the above approach does not model the uncertainty in our estimate of the parameters, and 
therefore ignores the induced uncertainty in the output probabilities, and the location of the decision 
boundary. To capture this additional uncertainty, we can use a Bayesian approach to approximate 
the posterior $p(\boldsymbol{\theta}|\mathcal{D})$. (See Section 10.5 for details.) Given this, we can approximate the posterior 
predictive distribution using a Monte Carlo approximation: 

$$p(y = 1|x, \mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^{S} p(y = 1|x, \boldsymbol{\theta}^{s}) \tag{4.205}$$

where $\boldsymbol{\theta}^{s} \sim p(\boldsymbol{\theta}|\mathcal{D})$ is a posterior sample. Figure 4.20b plots the mean and 95% credible interval of 
this function. We see that there is now a range of predicted probabilities for each input. We can 
also compute a distribution over the location of the decision boundary by using the Monte Carlo 
approximation 

$$p(x^*|\mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^{S} \delta\left(x^* - \left(-\frac{b^{s}}{w^{s}}\right)\right) \tag{4.206}$$

where $(b^{s}, w^{s}) = \boldsymbol{\theta}^{s}$. The 95% credible interval for this distribution is shown by the “fat” vertical line 
in Figure 4.20b. 

Figure 4.21: Distribution of arrival times for two different shipping companies. ETA is the expected time of 
arrival. A’s distribution has greater uncertainty, and may be too risky. From https://bit.ly/39bc42L. 
Used with kind permission of Brendan Hasz. 

Although carefully modeling our uncertainty may not matter for this application, it can be 
important in risk-sensitive applications, such as health care and finance, as we discuss in Chapter 5. 
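
The Monte Carlo approximations in Equations (4.205) and (4.206) can be sketched as below. We do not
fit the model to the iris data here. Instead, the posterior samples $(b^s, w^s)$ are hypothetical
stand-ins, drawn so that the implied decision boundary is centered near 5.5 cm; in practice they would
come from an approximate inference method. All names and numerical values are our own assumptions.

```python
import math
import random

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

# Hypothetical posterior samples theta^s = (b^s, w^s). These are NOT fitted to
# the iris data; they are synthetic draws used only to illustrate the formulas.
rng = random.Random(0)
S = 2000
theta_samples = []
for _ in range(S):
    w = rng.gauss(5.0, 0.5)          # slope
    x_star = rng.gauss(5.5, 0.1)     # implied boundary location
    theta_samples.append((-w * x_star, w))   # b = -w * x_star

def posterior_predictive(x, theta_samples):
    """Equation (4.205): average the plug-in predictions over the samples."""
    return sum(sigmoid(b + w * x) for b, w in theta_samples) / len(theta_samples)

# Equation (4.206): each sample induces one boundary location, x* = -b/w.
boundaries = sorted(-b / w for b, w in theta_samples)
lo = boundaries[int(0.025 * S)]
hi = boundaries[int(0.975 * S)]

for x in (4.5, 5.5, 6.5):
    print(f"p(y=1 | x={x}, D) ~ {posterior_predictive(x, theta_samples):.3f}")
print(f"95% credible interval for x*: ({lo:.2f}, {hi:.2f})")
```

The same pattern applies to any model for which we can draw posterior samples: evaluate the quantity
of interest once per sample, then summarize the resulting collection.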


4.6.7.3 Example: binary input, scalar output 


Now suppose we want to predict the delivery time for a package, $y \in \mathbb{R}$, if shipped by company A vs 
B. We can encode the company id using a binary feature $x \in \{0, 1\}$, where $x = 0$ means company A 
and $x = 1$ means company B. We will use the following discriminative model for this problem: 

$$p(y|x, \boldsymbol{\theta}) = \mathcal{N}(y|\mu_x, \sigma_x^2) \tag{4.207}$$

where $\mathcal{N}(y|\mu, \sigma^2)$ is the Gaussian distribution 

$$\mathcal{N}(y|\mu, \sigma^2) \triangleq \frac{1}{\sqrt{2\pi\sigma^2}} e^{-\frac{1}{2\sigma^2}(y - \mu)^2} \tag{4.208}$$

and $\boldsymbol{\theta} = (\mu_0, \mu_1, \sigma_0, \sigma_1)$ are the parameters of the model. We can fit this model using maximum 
likelihood estimation as we discuss in Section 4.2.5; alternatively, we can adopt a Bayesian approach, 
as we discuss in Section 4.6.4. 

The advantage of the Bayesian approach is that by capturing uncertainty in the parameters $\boldsymbol{\theta}$, we 
also capture uncertainty in our forecasts $p(y|x, \mathcal{D})$, whereas using a plug-in approximation $p(y|x, \hat{\boldsymbol{\theta}})$ 
would underestimate this uncertainty. For example, suppose we have only used each company once, so 
our training set has the form $\mathcal{D} = \{(x_1 = 0, y_1 = 15), (x_2 = 1, y_2 = 20)\}$. As we show in Section 4.2.5, 
the MLE for the means will be the empirical means, $\hat{\mu}_0 = 15$ and $\hat{\mu}_1 = 20$, but the MLE for the 
standard deviations will be zero, $\hat{\sigma}_0 = \hat{\sigma}_1 = 0$, since we only have a single sample from each “class”. 
The resulting plug-in prediction will therefore not capture any uncertainty. 
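
The degenerate plug-in estimate is easy to reproduce. The following minimal sketch fits the
class-conditional Gaussian by maximum likelihood on the two-point dataset from the text; the function
name `fit_mle` is our own.

```python
import math

# Training set from the text: (x_n, y_n), with x = 0 for company A, x = 1 for B.
data = [(0, 15.0), (1, 20.0)]

def fit_mle(data):
    """Per-class MLE of the mean and standard deviation (dividing by N, not N-1)."""
    params = {}
    for x in sorted({xn for xn, _ in data}):
        ys = [yn for xn, yn in data if xn == x]
        mu = sum(ys) / len(ys)
        sigma = math.sqrt(sum((yn - mu) ** 2 for yn in ys) / len(ys))
        params[x] = (mu, sigma)
    return params

params = fit_mle(data)
print(params)   # each class has a single sample, so both sigmas are zero
```

With $\hat{\sigma}_x = 0$, the plug-in predictive collapses to a point mass at $\hat{\mu}_x$, which is clearly
overconfident given that we have seen only one delivery per company.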

To see why modeling the uncertainty is important, consider Figure 4.21. We see that the expected 
time of arrival (ETA) for company A is less than for company B; however, the variance of A’s 
distribution is larger, which makes it a risky choice if you want to be confident the package will 
arrive by the specified deadline. (For more details on how to choose optimal actions in the presence 
of uncertainty, see Chapter 5.) 

Of course, the above example is extreme, because we assumed we only had one example from each 
delivery company. However, this kind of problem occurs whenever we have few examples of a given 
kind of input, as can happen whenever the data has a long tail of novel patterns, such as a new 
combination of words or categorical features. 


4.6.7.4 Scaling up 


The above examples were both extremely simple, involving 1d input and 1d output, and just 2-4 
parameters. Most practical problems involve high dimensional inputs, and sometimes high dimensional 
outputs, and therefore use models with lots of parameters. Unfortunately, computing the posterior, 
$p(\boldsymbol{\theta}|\mathcal{D})$, and the posterior predictive, $p(\boldsymbol{y}|\boldsymbol{x}, \mathcal{D})$, can be computationally challenging in such cases. 
We discuss this issue in Section 4.6.8. 


4.6.8 Computational issues 


Given a likelihood $p(\mathcal{D}|\boldsymbol{\theta})$ and a prior $p(\boldsymbol{\theta})$, we can compute the posterior $p(\boldsymbol{\theta}|\mathcal{D})$ using Bayes’ rule. 
However, actually performing this computation is usually intractable, except for simple special cases, 
such as conjugate models (Section 4.6.1), or models where all the latent variables come from a small 
finite set of possible values. We therefore need to approximate the posterior. There are a large variety 
of methods for performing approximate posterior inference, which trade off accuracy, simplicity, 
and speed. We briefly discuss some of these algorithms below, but go into more detail in the sequel 
to this book, [Mur23]. (See also [MFR20] for a review of various approximate inference methods, 
starting with Bayes’ original method in 1763.) 

As a running example, we will use the problem of approximating the posterior of a beta-Bernoulli 
model. Specifically, the goal is to approximate 

$$p(\theta|\mathcal{D}) \propto \left[\prod_{n=1}^{N} \text{Ber}(y_n|\theta)\right] \text{Beta}(\theta|1, 1) \tag{4.209}$$

where $\mathcal{D}$ consists of 10 heads and 1 tail (so the total number of observations is $N = 11$), and we 
use a uniform prior. Although we can compute this posterior exactly (see Figure 4.22), using the 
method discussed in Section 4.6.2, this serves as a useful pedagogical example since we can compare 
the approximation to the exact answer. Also, since the target distribution is just 1d, it is easy to 
visualize the results. (Note, however, that the problem is not completely trivial, since the posterior is 
highly skewed, due to the use of an imbalanced sample of 10 heads and 1 tail.) 

Figure 4.22: Approximating the posterior of a beta-Bernoulli model. (a) Grid approximation using 20 grid 
points. (b) Laplace approximation. Generated by laplace_approx_beta_binom_jax.ipynb. 


4.6.8.1 Grid approximation 


The simplest approach to approximate posterior inference is to partition the space of possible values 
for the unknowns into a finite set of possibilities, call them $\boldsymbol{\theta}_1, \ldots, \boldsymbol{\theta}_K$, and then to approximate the 
posterior by brute-force enumeration, as follows: 

$$p(\boldsymbol{\theta} = \boldsymbol{\theta}_k|\mathcal{D}) \approx \frac{p(\mathcal{D}|\boldsymbol{\theta}_k) p(\boldsymbol{\theta}_k)}{p(\mathcal{D})} = \frac{p(\mathcal{D}|\boldsymbol{\theta}_k) p(\boldsymbol{\theta}_k)}{\sum_{k'=1}^{K} p(\mathcal{D}, \boldsymbol{\theta}_{k'})} \tag{4.210}$$

This is called a grid approximation. In Figure 4.22a, we illustrate this method applied to our 1d 
problem. We see that it is easily able to capture the skewed posterior. Unfortunately, this approach 
does not scale to problems in more than 2 or 3 dimensions, because the number of grid points grows 
exponentially with the number of dimensions. 
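
As a minimal sketch of Equation (4.210) for the running beta-Bernoulli example (10 heads, 1 tail,
uniform prior, 20 grid points), we can write the following. This is our own standard-library
illustration rather than the notebook code, and the function name `grid_posterior` is hypothetical.

```python
def grid_posterior(n_heads, n_tails, K=20):
    """Grid approximation to the posterior of a Bernoulli parameter."""
    grid = [k / (K - 1) for k in range(K)]     # theta_1, ..., theta_K on [0, 1]
    prior = [1.0 for _ in grid]                # Beta(1, 1) is uniform
    # Unnormalized joint p(D, theta_k) = p(D | theta_k) p(theta_k).
    joint = [p * th ** n_heads * (1 - th) ** n_tails for th, p in zip(grid, prior)]
    Z = sum(joint)                             # approximates p(D) on the grid
    return grid, [j / Z for j in joint]

grid, probs = grid_posterior(n_heads=10, n_tails=1, K=20)
grid_mean = sum(th * p for th, p in zip(grid, probs))
exact_mean = 11 / 13                           # mean of the exact Beta(11, 2) posterior

print(f"grid posterior mean:  {grid_mean:.3f}")
print(f"exact posterior mean: {exact_mean:.3f}")
```

Even with only 20 points, the grid estimate of the posterior mean is close to the exact value. With
$D$ parameters and $K$ points per dimension, however, the same code would need $K^D$ evaluations.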


4.6.8.2 Quadratic (Laplace) approximation 


In this section, we discuss a simple way to approximate the posterior using a multivariate Gaussian; 
this is known as a Laplace approximation or a quadratic approximation (see e.g., [TK86; 
RMC09]). 

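To derive this, suppose we write the posterior as follows: 

$$p(\boldsymbol{\theta}|\mathcal{D}) = \frac{1}{Z} e^{-\mathcal{E}(\boldsymbol{\theta})} \tag{4.211}$$

where $\mathcal{E}(\boldsymbol{\theta}) = -\log p(\boldsymbol{\theta}, \mathcal{D})$ is called an energy function, and $Z = p(\mathcal{D})$ is the normalization constant. 
Performing a Taylor series expansion around the mode $\hat{\boldsymbol{\theta}}$ (i.e., the lowest energy state) we get 

$$\mathcal{E}(\boldsymbol{\theta}) \approx \mathcal{E}(\hat{\boldsymbol{\theta}}) + (\boldsymbol{\theta} - \hat{\boldsymbol{\theta}})^{\mathsf{T}} \boldsymbol{g} + \frac{1}{2} (\boldsymbol{\theta} - \hat{\boldsymbol{\theta}})^{\mathsf{T}} \mathbf{H} (\boldsymbol{\theta} - \hat{\boldsymbol{\theta}}) \tag{4.212}$$

where $\boldsymbol{g}$ is the gradient at the mode, and $\mathbf{H}$ is the Hessian. Since $\hat{\boldsymbol{\theta}}$ is the mode, the gradient term is 
zero. Hence 

$$\hat{p}(\boldsymbol{\theta}, \mathcal{D}) = e^{-\mathcal{E}(\hat{\boldsymbol{\theta}})} \exp\left[-\frac{1}{2} (\boldsymbol{\theta} - \hat{\boldsymbol{\theta}})^{\mathsf{T}} \mathbf{H} (\boldsymbol{\theta} - \hat{\boldsymbol{\theta}})\right] \tag{4.213}$$
$$\hat{p}(\boldsymbol{\theta}|\mathcal{D}) = \frac{1}{Z} \hat{p}(\boldsymbol{\theta}, \mathcal{D}) = \mathcal{N}(\boldsymbol{\theta}|\hat{\boldsymbol{\theta}}, \mathbf{H}^{-1}) \tag{4.214}$$
$$Z = e^{-\mathcal{E}(\hat{\boldsymbol{\theta}})} (2\pi)^{D/2} |\mathbf{H}|^{-\frac{1}{2}} \tag{4.215}$$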


The last line follows from the normalization constant of the multivariate Gaussian. 

The Laplace approximation is easy to apply, since we can leverage existing optimization algorithms 
to compute the MAP estimate, and then we just have to compute the Hessian at the mode. (In high 
dimensional spaces, we can use a diagonal approximation.) 

In Figure 4.22b, we illustrate this method applied to our 1d problem. Unfortunately we see that it 
is not a particularly good approximation. This is because the posterior is skewed, whereas a Gaussian 
is symmetric. In addition, the parameter of interest lies in the constrained interval $\theta \in [0, 1]$, whereas 
the Gaussian assumes an unconstrained space, $\theta \in \mathbb{R}$. Fortunately, we can solve this latter problem 
by using a change of variable. For example, in this case we can apply the Laplace approximation to 
$\alpha = \text{logit}(\theta)$. This is a common trick to simplify the job of inference. 
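
The following minimal sketch applies the Laplace approximation to the running 1d example in the
original (constrained) parameterization. For simplicity we use the closed-form mode of the
Beta(11, 2) posterior and a finite-difference second derivative in place of an optimizer and an
automatic-differentiation library; these are our own simplifications, and the helper names are
hypothetical.

```python
import math
from statistics import NormalDist

n_heads, n_tails = 10, 1

def energy(theta):
    """E(theta) = -log p(theta, D), up to an additive constant (uniform prior)."""
    return -(n_heads * math.log(theta) + n_tails * math.log(1.0 - theta))

def second_derivative(f, x, h=1e-4):
    """Central finite-difference estimate of f''(x)."""
    return (f(x + h) - 2.0 * f(x) + f(x - h)) / h ** 2

theta_map = n_heads / (n_heads + n_tails)   # mode of Beta(11, 2), i.e. 10/11
H = second_derivative(energy, theta_map)    # 1d Hessian of the energy at the mode
laplace = NormalDist(mu=theta_map, sigma=math.sqrt(1.0 / H))

# Probability mass the Gaussian places outside the valid range theta <= 1.
mass_above_one = 1.0 - laplace.cdf(1.0)

print(f"mode = {theta_map:.3f}, H = {H:.1f}, std = {laplace.stdev:.3f}")
print(f"mass above 1 under the Laplace approximation: {mass_above_one:.3f}")
```

The nonzero mass above $\theta = 1$ makes the problem with the constrained parameterization explicit,
and motivates the change of variable mentioned above.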


4.6.8.3 Variational approximation 


In Section 4.6.8.2, we discussed the Laplace approximation, which uses an optimization procedure 
to find the MAP estimate, and then approximates the curvature of the posterior at that point 
based on the Hessian. In this section, we discuss variational inference (VI), which is another 
optimization-based approach to posterior inference, but which has much more modeling flexibility 
(and thus can give a much more accurate approximation). 

VI attempts to approximate an intractable probability distribution, such as $p(\boldsymbol{\theta}|\mathcal{D})$, with one that 
is tractable, $q(\boldsymbol{\theta})$, so as to minimize some discrepancy $D$ between the distributions: 

$$q^* = \operatorname*{argmin}_{q \in \mathcal{Q}} D(q, p) \tag{4.216}$$

where $\mathcal{Q}$ is some tractable family of distributions (e.g., multivariate Gaussian). If we define $D$ to be 
the KL divergence (see Section 6.2), then we can derive a lower bound to the log marginal likelihood; 
this quantity is known as the evidence lower bound or ELBO. By maximizing the ELBO, we can 
improve the quality of the posterior approximation. See the sequel to this book, [Mur23], for details. 
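
For reference, the lower bound mentioned above can be written out in two lines. This is the standard
identity, stated here under the assumption that the discrepancy is
$D(q, p) = D_{\mathrm{KL}}(q(\boldsymbol{\theta}) \,\|\, p(\boldsymbol{\theta}|\mathcal{D}))$; the full treatment is deferred to the sequel.

```latex
\begin{align}
D_{\mathrm{KL}}\left(q(\boldsymbol{\theta}) \,\|\, p(\boldsymbol{\theta}|\mathcal{D})\right)
  &= \mathbb{E}_{q(\boldsymbol{\theta})}\left[\log q(\boldsymbol{\theta})
     - \log \frac{p(\boldsymbol{\theta}, \mathcal{D})}{p(\mathcal{D})}\right] \\
  &= \log p(\mathcal{D})
     - \underbrace{\mathbb{E}_{q(\boldsymbol{\theta})}\left[\log p(\boldsymbol{\theta}, \mathcal{D})
     - \log q(\boldsymbol{\theta})\right]}_{\text{ELBO}(q)}
\end{align}
```

Since the KL divergence is non-negative, $\text{ELBO}(q) \le \log p(\mathcal{D})$, and since $\log p(\mathcal{D})$ does not
depend on $q$, maximizing the ELBO is equivalent to minimizing the KL divergence. Note that the ELBO
only requires the unnormalized joint $p(\boldsymbol{\theta}, \mathcal{D})$.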


4.6.8.4 Markov Chain Monte Carlo (MCMC) approximation 


Although VI is a fast, optimization-based method, it can give a biased approximation to the posterior, 
since it is restricted to a specific functional form $q \in \mathcal{Q}$. A more flexible approach is to use a 
non-parametric approximation in terms of a set of samples, $q(\boldsymbol{\theta}) \approx \frac{1}{S} \sum_{s=1}^{S} \delta(\boldsymbol{\theta} - \boldsymbol{\theta}^{s})$. This is called a 
Monte Carlo approximation to the posterior. The key issue is how to create the posterior samples 
$\boldsymbol{\theta}^{s} \sim p(\boldsymbol{\theta}|\mathcal{D})$ efficiently, without having to evaluate the normalization constant $p(\mathcal{D}) = \int p(\boldsymbol{\theta}, \mathcal{D}) \, d\boldsymbol{\theta}$. 
A common approach to this problem is known as Markov chain Monte Carlo or MCMC. If 
we augment this algorithm with gradient-based information, derived from $\nabla \log p(\boldsymbol{\theta}, \mathcal{D})$, we can 
significantly speed up the method; this is called Hamiltonian Monte Carlo or HMC. See the 
sequel to this book, [Mur23], for details. 
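
To illustrate how samples can be drawn without knowing $p(\mathcal{D})$, here is a minimal sketch of
random-walk Metropolis, one of the simplest MCMC algorithms, applied to the running beta-Bernoulli
example. This is not HMC (it uses no gradient information), and the step size, chain length, and
burn-in below are arbitrary choices made for illustration.

```python
import math
import random

n_heads, n_tails = 10, 1

def log_joint(theta):
    """log p(theta, D) up to a constant; -inf outside the support (0, 1)."""
    if not 0.0 < theta < 1.0:
        return -math.inf
    return n_heads * math.log(theta) + n_tails * math.log(1.0 - theta)

def metropolis(log_target, init, n_steps, step_size, rng):
    theta, logp = init, log_target(init)
    samples = []
    for _ in range(n_steps):
        proposal = theta + rng.gauss(0.0, step_size)   # symmetric random walk
        logp_prop = log_target(proposal)
        # Accept with probability min(1, p(proposal, D) / p(theta, D)).
        # The unknown normalizer p(D) cancels in this ratio.
        if rng.random() < math.exp(min(0.0, logp_prop - logp)):
            theta, logp = proposal, logp_prop
        samples.append(theta)
    return samples

rng = random.Random(0)
chain = metropolis(log_joint, init=0.5, n_steps=20000, step_size=0.1, rng=rng)
kept = chain[2000:]                                    # discard burn-in
mc_mean = sum(kept) / len(kept)

print(f"MCMC posterior mean:  {mc_mean:.3f}")
print(f"exact posterior mean: {11 / 13:.3f}")
```

Proposals outside $(0, 1)$ have zero target density and are always rejected, so the chain respects the
constraint on $\theta$ without any change of variable.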


4.7 Frequentist statistics * 


The approach to statistical inference that we described in Section 4.6 is called Bayesian statistics. 
It treats parameters of models just like any other unknown random variable, and applies the rules 
of probability theory to infer them from data. Attempts have been made to devise approaches to 
statistical inference that avoid treating parameters like random variables, and which thus avoid 
the use of priors and Bayes’ rule. This alternative approach is known as frequentist statistics, 
classical statistics or orthodox statistics. 

The basic idea (formalized in Section 4.7.1) is to represent uncertainty by calculating how a quantity 
estimated from data (such as a parameter or a predicted label) would change if the data were changed. 
It is this notion of variation across repeated trials that forms the basis for modeling uncertainty 
used by the frequentist approach. By contrast, the Bayesian approach views probability in terms 
of information rather than repeated trials. This allows the Bayesian to compute the probability of 
one-off events, as we discussed in Section 2.1.1. Perhaps more importantly, the Bayesian approach 
avoids certain paradoxes that plague the frequentist approach (see Section 4.7.5 and Section 5.5.4). 
These pathologies led the famous statistician George Box to say: 


I believe that it would be very difficult to persuade an intelligent person that current [frequentist] 
statistical practice was sensible, but that there would be much less difficulty with an approach 
via likelihood and Bayes’ theorem. — George Box, 1962 (quoted in [Jay76]). 


Nevertheless, it is useful to be familiar with frequentist statistics, since it is widely used, and has 
some key concepts that are useful even for Bayesians [Rub84]. 


4.7.1 Sampling distributions 


In frequentist statistics, uncertainty is not represented by the posterior distribution of a random 
variable, but instead by the sampling distribution of an estimator. 

The term “estimator” is defined in our discussion of decision theory in Section 5.1, but in brief, an 
estimator $\delta : \mathcal{D} \to \mathcal{A}$ is a decision procedure that specifies what action to take given some observed 
data. The action could be to predict a class label, or the next observation, or to predict the unknown 
parameters. In the latter case, the estimator is often denoted by $\hat{\boldsymbol{\theta}}$, but this notation is ambiguous, 
since it looks like it represents a parameter vector rather than a function. So instead we will use 
the notation $\hat{\Theta}$. This function could compute the MLE, or the method of moments estimate, etc. 
The output of this function, when applied to a specific dataset of size $N$, is denoted $\hat{\boldsymbol{\theta}} = \hat{\Theta}(\mathcal{D})$, where 
$\mathcal{D} = \{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_N\}$. 

The key idea in frequentist statistics is to view the data $\mathcal{D}$ as a random variable, and the parameters 
from which the data are drawn, $\boldsymbol{\theta}^*$, as a fixed but unknown constant. Thus $\hat{\boldsymbol{\theta}} = \hat{\Theta}(\mathcal{D})$ is a random 
variable, and its distribution is known as the sampling distribution of the estimator. To understand 
what this means, suppose we create $S$ different datasets, each of the form 

$$\mathcal{D}^{(s)} = \{\boldsymbol{x}_n \sim p(\boldsymbol{x}_n|\boldsymbol{\theta}^*) : n = 1:N\} \tag{4.217}$$

We denote this by $\mathcal{D}^{(s)} \sim \boldsymbol{\theta}^*$ for brevity. Now we apply the estimator to each $\mathcal{D}^{(s)}$ to get a set 
of estimates, $\{\hat{\Theta}(\mathcal{D}^{(s)})\}$. As we let $S \to \infty$, the distribution induced by this set is the sampling 
distribution of the estimator. More precisely, we have 

$$\text{SamplingDist}(\hat{\Theta}, \boldsymbol{\theta}^*) = \text{PushThrough}(p(\mathcal{D}|\boldsymbol{\theta}^*), \hat{\Theta}) \tag{4.218}$$


where we push the data distribution through the estimator function to induce a distribution of 
estimates. In some cases, we can compute the sampling distribution analytically, as we discuss 
in Section 4.7.2, although typically we need to approximate it by Monte Carlo, as we discuss in 
Section 4.7.3. 


4.7.2 Gaussian approximation of the sampling distribution of the MLE 


The most common estimator is the MLE. When the sample size becomes large, the sampling 
distribution of the MLE for certain models becomes Gaussian. This is known as the asymptotic 
normality of the sampling distribution. More formally, we have the following result: 


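Theorem 4.7.1. If the parameters are identifiable, then 

$$\text{SamplingDist}(\hat{\Theta}^{\text{mle}}, \boldsymbol{\theta}^*) \to \mathcal{N}(\cdot|\boldsymbol{\theta}^*, (N \mathbf{F}(\boldsymbol{\theta}^*))^{-1}) \tag{4.219}$$

where $\mathbf{F}(\boldsymbol{\theta}^*)$ is the Fisher information matrix, defined in Equation (4.220). 

Equivalently, the above result says that the distribution of $\sqrt{N \mathbf{F}(\boldsymbol{\theta}^*)}(\hat{\boldsymbol{\theta}} - \boldsymbol{\theta}^*)$ approaches $\mathcal{N}(\mathbf{0}, \mathbf{I})$, 
where $\hat{\boldsymbol{\theta}} = \hat{\Theta}^{\text{mle}}(\mathcal{D})$. 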

The Fisher information matrix measures the amount of curvature of the log-likelihood surface at 
its peak, as we show below. More formally, the Fisher information matrix (FIM) is defined to 
be the covariance of the gradient of the log likelihood (also called the score function): 


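$$\mathbf{F}(\boldsymbol{\theta}) \triangleq \mathbb{E}_{\boldsymbol{x} \sim p(\boldsymbol{x}|\boldsymbol{\theta})}\left[\nabla \log p(\boldsymbol{x}|\boldsymbol{\theta}) \nabla \log p(\boldsymbol{x}|\boldsymbol{\theta})^{\mathsf{T}}\right] \tag{4.220}$$

Hence the $(i, j)$'th entry has the form 

$$F_{ij} = \mathbb{E}_{\boldsymbol{x} \sim \boldsymbol{\theta}}\left[\left(\frac{\partial}{\partial \theta_i} \log p(\boldsymbol{x}|\boldsymbol{\theta})\right) \left(\frac{\partial}{\partial \theta_j} \log p(\boldsymbol{x}|\boldsymbol{\theta})\right)\right] \tag{4.221}$$

One can show the following result. 

Theorem 4.7.2. If $\log p(\boldsymbol{x}|\boldsymbol{\theta})$ is twice differentiable, and under certain regularity conditions, the 
FIM is equal to the expected Hessian of the NLL, i.e., 

$$F_{ij} = -\mathbb{E}_{\boldsymbol{x} \sim \boldsymbol{\theta}}\left[\frac{\partial^2}{\partial \theta_i \partial \theta_j} \log p(\boldsymbol{x}|\boldsymbol{\theta})\right] \tag{4.222}$$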

Thus we can interpret the FIM as the Hessian of the NLL. This helps us understand the result 
in Equation (4.219): a log-likelihood function with high curvature (large Hessian) will result in a 
low variance estimate, since the parameters are “well determined” by the data, and hence robust to 
repeated sampling. 

Figure 4.23: Bootstrap (top row) vs Bayes (bottom row). The N data cases were generated from Ber($\theta$ = 0.7). 
Left column: N = 10. Right column: N = 100. (a-b) A bootstrap approximation to the sampling distribution 
of the MLE for a Bernoulli distribution. We show the histogram derived from B = 10,000 bootstrap samples. 
(c-d) Histogram of 10,000 samples from the posterior distribution using a uniform prior. Generated by 
bootstrap_demo_bernoulli.ipynb. 


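In the scalar case, we have that $\mathbb{V}[\hat{\theta} - \theta^*] \to \frac{1}{N F(\theta^*)}$. The square root of the variance of the 
sampling distribution of an estimator is known as its standard error or se. Hence we can say that 
the distribution of $\frac{\hat{\theta} - \theta^*}{\text{se}}$ approaches $\mathcal{N}(0, 1)$. In practice the se is not known, but it can be estimated 
from data. For example, suppose $X_n \sim \text{Ber}(\theta^*)$ and let $\hat{\theta} = \frac{1}{N} \sum_{n=1}^{N} X_n$ be the MLE. The standard 
error is $\text{se} = \sqrt{\mathbb{V}[\hat{\theta}]} = \sqrt{\theta^*(1 - \theta^*)/N}$, so the estimated standard error is $\hat{\text{se}} = \sqrt{\hat{\theta}(1 - \hat{\theta})/N}$. 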
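
We can check the standard error formula by simulating the sampling distribution directly, as in
Equation (4.217): draw many datasets from a known $\theta^*$, apply the estimator to each, and look at the
spread of the estimates. This is a minimal sketch with values chosen by us ($\theta^* = 0.7$, $N = 100$,
$S = 5000$); the function names are hypothetical.

```python
import math
import random
from statistics import fmean, pstdev

def mle_bernoulli(data):
    """MLE of a Bernoulli parameter: the fraction of ones."""
    return fmean(data)

def simulate_sampling_distribution(theta_star, N, S, rng):
    """Apply the estimator to S datasets, each of size N, drawn from Ber(theta_star)."""
    estimates = []
    for _ in range(S):
        data = [1 if rng.random() < theta_star else 0 for _ in range(N)]
        estimates.append(mle_bernoulli(data))
    return estimates

rng = random.Random(0)
theta_star, N = 0.7, 100
estimates = simulate_sampling_distribution(theta_star, N, S=5000, rng=rng)

empirical_se = pstdev(estimates)
# For a Bernoulli, F(theta) = 1 / (theta (1 - theta)), so 1/(N F) = theta (1 - theta) / N.
theoretical_se = math.sqrt(theta_star * (1 - theta_star) / N)

print(f"empirical se:   {empirical_se:.4f}")
print(f"theoretical se: {theoretical_se:.4f}")
```

Of course, this simulation requires knowing $\theta^*$, which is exactly what we do not have in practice.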


4.7.3 Bootstrap approximation of the sampling distribution of any estimator 


In cases where the estimator is a complex function of the data (e.g., not just an MLE), or when the 
sample size is small, we can approximate its sampling distribution using a Monte Carlo technique 
known as the bootstrap. 

The idea is simple. If we knew the true parameters $\boldsymbol{\theta}^*$, we could generate many (say $S$) fake 
datasets, each of size $N$, from the true distribution, using $\mathcal{D}^{(s)} = \{\boldsymbol{x}_n \sim p(\boldsymbol{x}_n|\boldsymbol{\theta}^*) : n = 1:N\}$. We 
could then compute our estimate from each sample, $\hat{\boldsymbol{\theta}}^{s} = \hat{\Theta}(\mathcal{D}^{(s)})$ and use the empirical distribution 
of the resulting $\hat{\boldsymbol{\theta}}^{s}$ as our estimate of the sampling distribution, as in Equation (4.218). Since $\boldsymbol{\theta}^*$ is 
unknown, we can use the dataset itself as an empirical approximation to $p(\boldsymbol{x}_n|\boldsymbol{\theta}^*)$. More precisely, 
the idea is to generate each $\mathcal{D}^{(s)}$ by sampling $N$ data points with replacement from the original 
dataset.⁶ 

Figure 4.23(a-b) shows an example where we compute the sampling distribution of the MLE for a 
Bernoulli using the bootstrap. When N = 10, we see that the sampling distribution is asymmetric, 
and therefore quite far from Gaussian, but when N = 100, the distribution looks more Gaussian, as 
theory suggests (see Section 4.7.2). 

The number of unique data points in a bootstrap sample is just $0.632 \times N$, on average. (To see 
this, note that the probability an item is picked at least once is $1 - (1 - 1/N)^N$, which approaches 
$1 - e^{-1} \approx 0.632$ for large $N$.) However, there are more sophisticated versions of bootstrap that 
improve on this (see e.g., [Efr87; EH16]). 
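
A minimal sketch of the non-parametric bootstrap for the Bernoulli MLE is shown below. The dataset
(70 heads out of 100) is hypothetical, and the code is our own illustration rather than the notebook
used to produce Figure 4.23. We also check the 0.632 rule empirically.

```python
import math
import random
from statistics import fmean, pstdev

def bootstrap(data, estimator, S, rng):
    """Apply `estimator` to S datasets resampled with replacement from `data`."""
    N = len(data)
    return [estimator(rng.choices(data, k=N)) for _ in range(S)]

rng = random.Random(0)
data = [1] * 70 + [0] * 30              # hypothetical dataset: 70 heads, 30 tails
N = len(data)

boot_estimates = bootstrap(data, fmean, S=2000, rng=rng)
boot_se = pstdev(boot_estimates)        # bootstrap estimate of the standard error
theta_hat = fmean(data)
plugin_se = math.sqrt(theta_hat * (1 - theta_hat) / N)

# Fraction of distinct original items appearing in each bootstrap sample.
unique_fracs = [len(set(rng.choices(range(N), k=N))) / N for _ in range(2000)]
avg_unique = fmean(unique_fracs)

print(f"bootstrap se: {boot_se:.4f}, plug-in se: {plugin_se:.4f}")
print(f"average unique fraction: {avg_unique:.3f}")
```

Because `bootstrap` takes the estimator as an argument, the same function can be reused for
estimators whose sampling distribution has no convenient analytical form, such as a median.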


4.7.3.1 Bootstrap is a “poor man’s” posterior 


A natural question is: what is the connection between the parameter estimates $\hat{\boldsymbol{\theta}}^{s} = \hat{\Theta}(\mathcal{D}^{(s)})$ computed 
by the bootstrap and parameter values sampled from the posterior, $\boldsymbol{\theta}^{s} \sim p(\cdot|\mathcal{D})$? Conceptually they 
are quite different. But in the common case that the estimator is MLE and the prior is not very 
strong, they can be quite similar. For example, Figure 4.23(c-d) shows an example where we compute 
the posterior using a uniform Beta(1,1) prior, and then sample from it. We see that the posterior 
and the sampling distribution are quite similar. So one can think of the bootstrap distribution as a 
“poor man’s” posterior [HTF01, p235]. 

However, perhaps surprisingly, bootstrap can be slower than posterior sampling. The reason is that 
the bootstrap has to generate S sampled datasets, and then fit a model to each one. By contrast, in 
posterior sampling, we only have to “fit” a model once given a single dataset. (Some methods for 
speeding up the bootstrap when applied to massive data sets are discussed in [Kle+11].) 


4.7.4 Confidence intervals 


In frequentist statistics, we use the variability induced by the sampling distribution as a way to 
estimate uncertainty of a parameter estimate. 

In particular, we define a $100(1-\alpha)\%$ confidence interval for parameter $\theta$ as an estimator that 
returns an interval that captures the true parameter with probability at least $1-\alpha$. Denote the 
estimator by $I(\mathcal{D}) = (\ell(\mathcal{D}), u(\mathcal{D}))$. The sampling distribution of this estimator is the distribution 
that is induced by sampling $\mathcal{D} \sim \theta^*$ and then computing $I(\mathcal{D})$. We require that 

$$\Pr(\theta^* \in I(\mathcal{D}) | \mathcal{D} \sim \theta^*) \ge 1 - \alpha \tag{4.223}$$

It is common to set $\alpha = 0.05$, which yields a 95% CI. This means that, if we repeatedly sampled 
data, and computed $I(\mathcal{D})$ for each such dataset, then about 95% of such intervals would contain the true 
parameter $\theta^*$. 

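Let us give an example. Suppose that $\hat{\theta} = \hat{\Theta}(\mathcal{D})$ is an estimator for some parameter with true but 
unknown value $\theta^*$. Also, suppose that the sampling distribution of $\Delta = \theta^* - \hat{\theta}$ is known. Let $\underline{\delta}$ and $\overline{\delta}$ 
denote its $\alpha/2$ and $1 - \alpha/2$ quantiles, so 

$$\Pr(\underline{\delta} \le \Delta \le \overline{\delta}) = \Pr(\underline{\delta} \le \theta^* - \hat{\theta} \le \overline{\delta}) = 1 - \alpha \tag{4.224}$$

Rearranging we get 

$$\Pr(\hat{\theta} + \underline{\delta} \le \theta^* \le \hat{\theta} + \overline{\delta}) = 1 - \alpha \tag{4.225}$$

Hence we can construct a $100(1-\alpha)\%$ confidence interval as follows: 

$$I(\mathcal{D}) = (\ell(\mathcal{D}), u(\mathcal{D})) = (\hat{\theta}(\mathcal{D}) + \underline{\delta}(\mathcal{D}), \hat{\theta}(\mathcal{D}) + \overline{\delta}(\mathcal{D})) \tag{4.226}$$

6. This is called the non-parametric bootstrap. There is another variant, called the parametric bootstrap, in 
which we sample each $\mathcal{D}^{(s)}$ from $p(\boldsymbol{x}_n|\hat{\Theta}(\mathcal{D}))$; this requires a parametric model of the data. 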


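In some cases, we can analytically compute the sampling distribution of the above interval. 
However, it is more common to assume a Gaussian approximation to the sampling distribution, as in 
Section 4.7.2. In this case, we have $\hat{\theta} \approx \mathcal{N}(\theta^*, \hat{\text{se}}^2)$. Hence we can compute an approximate CI using 

$$I = (\hat{\theta} - z_{\alpha/2} \hat{\text{se}}, \hat{\theta} + z_{\alpha/2} \hat{\text{se}}) \tag{4.227}$$

where $z_{\alpha/2}$ is the $1 - \alpha/2$ quantile of the standard Gaussian (i.e., the point with $\alpha/2$ of the mass in 
the upper tail). If we set $\alpha = 0.05$, we have $z_{\alpha/2} = 1.96$, which justifies the common approximation 
$\hat{\theta} \pm 2\hat{\text{se}}$. 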

If the Gaussian approximation is not a good one, we can use a bootstrap approximation (see 
Section 4.7.3). In particular, we sample $S$ datasets from $\hat{\Theta}(\mathcal{D})$, and apply the estimator to each one 
to get $\hat{\Theta}(\mathcal{D}^{(s)})$; we then use the empirical distribution of $\hat{\Theta}(\mathcal{D}) - \hat{\Theta}(\mathcal{D}^{(s)})$ as an approximation to the 
sampling distribution of $\Delta$. We can then use the $\alpha/2$ and $1 - \alpha/2$ quantiles of this distribution to 
derive the CI (see [Was04, p110] for details). 
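
The frequentist guarantee in Equation (4.223) is a statement about repeated sampling, so it can be
checked by simulation. The sketch below computes the Gaussian-approximation interval of
Equation (4.227) for Bernoulli data and estimates how often it covers the true parameter. The settings
($\theta^* = 0.7$, $N = 100$, 5000 repetitions) and the function names are our own choices.

```python
import math
import random
from statistics import NormalDist, fmean

def gaussian_ci(data, alpha=0.05):
    """Approximate CI theta_hat +/- z * se_hat for a Bernoulli parameter."""
    N = len(data)
    theta_hat = fmean(data)
    se_hat = math.sqrt(theta_hat * (1 - theta_hat) / N)
    z = NormalDist().inv_cdf(1 - alpha / 2)     # about 1.96 for alpha = 0.05
    return theta_hat - z * se_hat, theta_hat + z * se_hat

def estimate_coverage(theta_star, N, n_repeats, rng, alpha=0.05):
    """Fraction of simulated datasets whose interval contains theta_star."""
    hits = 0
    for _ in range(n_repeats):
        data = [1 if rng.random() < theta_star else 0 for _ in range(N)]
        lo, hi = gaussian_ci(data, alpha)
        hits += lo <= theta_star <= hi
    return hits / n_repeats

rng = random.Random(0)
coverage = estimate_coverage(theta_star=0.7, N=100, n_repeats=5000, rng=rng)
print(f"estimated coverage of the nominal 95% interval: {coverage:.3f}")
```

Because the interval relies on a Gaussian approximation, the estimated coverage is only approximately
equal to the nominal level, and it can degrade for small $N$ or for $\theta^*$ near 0 or 1.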


4.7.5 Caution: Confidence intervals are not credible 


It is commonly believed that a 95% confidence interval $I$ for a parameter estimate $\theta$ given data $\mathcal{D}$ 
means that the true parameter lies in this interval with probability 0.95, i.e., $p(\theta^* \in I|\mathcal{D}) = 0.95$. 
However, this quantity is what a Bayesian credible interval computes (Section 4.6.6), but is not what 
a frequentist confidence interval computes. Instead the frequentist approach just means that the 
procedure for generating CIs will contain the true value 95% of the time. That is, if we repeatedly 
sample datasets $\mathcal{D}$ from $\theta^*$, and compute their CIs to get $I(\mathcal{D})$, then we have $\Pr(\theta^* \in I(\mathcal{D})) = 0.95$, 
as we explain in Section 4.7.4. Thus we see that these concepts are quite different: In the frequentist 
approach, $\theta$ is treated as an unknown fixed constant, and the data is treated as random. In the 
Bayesian approach, we treat the data as fixed (since it is known) and the parameter as random (since 
it is unknown). 

This counter-intuitive definition of confidence intervals can lead to bizarre results. Consider the 
following example from [Ber85, p11]. Suppose we draw two integers D = (y1, y2) from 


$$p(y|\theta) = \begin{cases} 0.5 & \text{if } y = \theta \\ 0.5 & \text{if } y = \theta + 1 \\ 0 & \text{otherwise} \end{cases} \tag{4.228}$$

If $\theta = 39$, we would expect the following outcomes each with probability 0.25: 

$$(39, 39), (39, 40), (40, 39), (40, 40) \tag{4.229}$$

Let $m = \min(y_1, y_2)$ and define the following interval: 

$$[\ell(\mathcal{D}), u(\mathcal{D})] = [m, m] \tag{4.230}$$

For the above samples this yields 

$$[39, 39], [39, 39], [39, 39], [40, 40] \tag{4.231}$$


Hence Equation (4.230) is clearly a 75% CI, since 39 is contained in 3/4 of these intervals. However, 
if we observe D = (39,40) then p(8 = 39|D) = 1.0, so we know that 0 must be 39, yet we only have 
75% “confidence” in this fact. We see that the CI will “cover” the true parameter 75% of the time, 
if we compute multiple Cls from different randomly sampled datasets, but if we just have a single 
observed dataset, and hence a single CI, then the frequentist “coverage” probability can be very 
misleading. 
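
The 75% coverage claim can be checked by simulation. The sketch below draws many datasets from the distribution in Equation (4.228) and counts how often the interval $[m, m]$ contains $\theta$; the helper names (`sample_pair`, `interval`, `coverage`) are illustrative only.

```python
import random


def sample_pair(theta, rng):
    """Draw D = (y1, y2); each y is theta or theta + 1 with probability 0.5."""
    return tuple(theta + rng.randint(0, 1) for _ in range(2))


def interval(pair):
    """The interval [m, m] with m = min(y1, y2), as in Equation (4.230)."""
    m = min(pair)
    return (m, m)


def coverage(theta, num_trials=100_000, seed=0):
    """Fraction of simulated datasets whose interval contains theta."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(num_trials):
        lower, upper = interval(sample_pair(theta, rng))
        hits += lower <= theta <= upper
    return hits / num_trials


if __name__ == "__main__":
    print(coverage(39))          # close to 0.75
    print(interval((39, 40)))    # the interval pins down theta exactly
```

The long-run coverage is about 0.75 even though, for the particular dataset (39, 40), the interval is certain to contain $\theta$.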

Another, less contrived, example is as follows. Suppose we want to estimate the parameter $\theta$ of a Bernoulli distribution. Let $\bar{y} = \frac{1}{N}\sum_{n=1}^N y_n$ be the sample mean. The MLE is $\hat{\theta} = \bar{y}$. An approximate 95% confidence interval for a Bernoulli parameter is $\bar{y} \pm 1.96\sqrt{\bar{y}(1 - \bar{y})/N}$ (this is called a Wald interval and is based on a Gaussian approximation to the Binomial distribution; compare to Equation (4.128)). Now consider a single trial, where $N = 1$ and $y_1 = 0$. The MLE is 0, which overfits, as we saw in Section 4.5.1. But our 95% confidence interval is also $(0, 0)$, which seems even worse. It can be argued that the above flaw is because we approximated the true sampling distribution with a Gaussian, or because the sample size was too small, or the parameter “too extreme”. However, the Wald interval can behave badly even for large $N$, and non-extreme parameters [BCD01]. By contrast, a Bayesian credible interval with a non-informative Jeffreys prior behaves in the way we would expect.
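
To make the degenerate case concrete, the following sketch computes the Wald interval and its exact coverage (by summing the Binomial pmf over all outcomes whose interval contains $\theta$). The function names are hypothetical, and the Jeffreys-prior interval is not implemented here.

```python
import math


def wald_interval(num_successes, num_trials, z=1.96):
    """Wald interval: ybar +/- z * sqrt(ybar * (1 - ybar) / N)."""
    ybar = num_successes / num_trials
    half_width = z * math.sqrt(ybar * (1 - ybar) / num_trials)
    return ybar - half_width, ybar + half_width


def binom_pmf(k, n, theta):
    """Probability of k successes in n Bernoulli(theta) trials."""
    return math.comb(n, k) * theta**k * (1 - theta) ** (n - k)


def wald_coverage(theta, num_trials):
    """Exact probability that the Wald interval contains theta."""
    total = 0.0
    for k in range(num_trials + 1):
        lower, upper = wald_interval(k, num_trials)
        if lower <= theta <= upper:
            total += binom_pmf(k, num_trials, theta)
    return total


if __name__ == "__main__":
    print(wald_interval(0, 1))      # the single-trial case from the text
    print(wald_coverage(0.5, 1))    # no outcome's interval contains 0.5
    print(wald_coverage(0.3, 50))   # coverage for a larger sample
```

With $N = 1$ the interval collapses to a single point whichever outcome is observed, so it cannot contain any $\theta$ strictly between 0 and 1.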

Several more interesting examples, along with Python code, can be found at [Van14]. See 
also [Hoe+14; Mor+16; Lyu+20; Cha+19b], who show that many people, including professional 
statisticians, misunderstand and misuse frequentist confidence intervals in practice, whereas Bayesian 
credible intervals do not suffer from these problems. 


4.7.6 The bias-variance tradeoff 


An estimator is a procedure applied to data which returns an estimand. Let $\hat{\theta}(\cdot)$ be the estimator, and $\hat{\theta}(\mathcal{D})$ be the estimand. In frequentist statistics, we treat the data as a random variable, drawn from some true but unknown distribution, $p^*(\mathcal{D})$; this induces a distribution over the estimand, $p^*(\hat{\theta}(\mathcal{D}))$, known as the sampling distribution (see Section 4.7.1). In this section, we discuss two key properties of this distribution, its bias and its variance, which we define below.


4.7.6.1 Bias of an estimator 


The bias of an estimator is defined as

$$\text{bias}(\hat{\theta}(\cdot)) \triangleq \mathbb{E}[\hat{\theta}(\mathcal{D})] - \theta^* \tag{4.232}$$

where $\theta^*$ is the true parameter value, and the expectation is wrt “nature’s distribution” $p(\mathcal{D}|\theta^*)$. If the bias is zero, the estimator is called unbiased. For example, the MLE for a Gaussian mean is unbiased:

$$\text{bias}(\hat{\mu}) = \mathbb{E}[\bar{x}] - \mu = \mathbb{E}\left[\frac{1}{N}\sum_{n=1}^N x_n\right] - \mu = \frac{N\mu}{N} - \mu = 0 \tag{4.233}$$

where $\bar{x}$ is the sample mean.

However, the MLE for a Gaussian variance, $\sigma^2_{\text{mle}} = \frac{1}{N}\sum_{n=1}^N (x_n - \bar{x})^2$, is not an unbiased estimator of $\sigma^2$. In fact, one can show (Exercise 4.7) that

$$\mathbb{E}[\sigma^2_{\text{mle}}] = \frac{N-1}{N}\sigma^2 \tag{4.234}$$

so the ML estimator slightly underestimates the variance. Intuitively, this is because we “use up” one of the data points to estimate the mean, so if we have a sample size of 1, we will estimate the variance to be 0. If, however, $\mu$ is known, the ML estimator is unbiased (see Exercise 4.8).

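Now consider the following estimator

$$\sigma^2_{\text{unb}} \triangleq \frac{1}{N-1}\sum_{n=1}^N (x_n - \bar{x})^2 = \frac{N}{N-1}\sigma^2_{\text{mle}} \tag{4.235}$$

This is an unbiased estimator, which we can easily prove as follows:

$$\mathbb{E}[\sigma^2_{\text{unb}}] = \frac{N}{N-1}\mathbb{E}[\sigma^2_{\text{mle}}] = \frac{N}{N-1}\frac{N-1}{N}\sigma^2 = \sigma^2 \tag{4.236}$$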
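
A quick Monte Carlo check of Equations (4.234) and (4.236) is sketched below. It averages both variance estimators over many simulated Gaussian datasets; the function name and the default settings are illustrative assumptions.

```python
import random


def mean_variance_estimates(n, sigma=1.0, num_datasets=20_000, seed=0):
    """Average of the ML and unbiased variance estimates over many datasets."""
    rng = random.Random(seed)
    mle_total = 0.0
    unb_total = 0.0
    for _ in range(num_datasets):
        xs = [rng.gauss(0.0, sigma) for _ in range(n)]
        xbar = sum(xs) / n
        sum_sq = sum((x - xbar) ** 2 for x in xs)
        mle_total += sum_sq / n        # divides by N
        unb_total += sum_sq / (n - 1)  # divides by N - 1
    return mle_total / num_datasets, unb_total / num_datasets


if __name__ == "__main__":
    mle_avg, unb_avg = mean_variance_estimates(n=5)
    # With sigma^2 = 1 and N = 5, theory predicts (N - 1) / N = 0.8 and 1.0.
    print(mle_avg, unb_avg)
```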


4.7.6.2 Variance of an estimator 


It seems intuitively reasonable that we want our estimator to be unbiased. However, being unbiased is not enough. For example, suppose we want to estimate the mean of a Gaussian from $\mathcal{D} = \{x_1, \ldots, x_N\}$. The estimator that just looks at the first data point, $\hat{\theta}(\mathcal{D}) = x_1$, is an unbiased estimator, but will generally be further from $\theta^*$ than the empirical mean $\bar{x}$ (which is also unbiased). So the variance of an estimator is also important.

We define the variance of an estimator as follows:

$$\mathbb{V}[\hat{\theta}] \triangleq \mathbb{E}[\hat{\theta}^2] - \left(\mathbb{E}[\hat{\theta}]\right)^2 \tag{4.237}$$

where the expectation is taken wrt $p(\mathcal{D}|\theta^*)$. This measures how much our estimate will change as the data changes. We can extend this to a covariance matrix for vector valued estimators.
Intuitively we would like the variance of our estimator to be as small as possible. Therefore, a natural question is: how low can the variance go? A famous result, called the Cramér-Rao lower bound, provides a lower bound on the variance of any unbiased estimator. More precisely, let $X_1, \ldots, X_N \sim p(X|\theta^*)$ and $\hat{\theta} = \hat{\theta}(x_1, \ldots, x_N)$ be an unbiased estimator of $\theta^*$. Then, under various smoothness assumptions on $p(X|\theta^*)$, we have $\mathbb{V}[\hat{\theta}] \geq \frac{1}{N F(\theta^*)}$, where $F(\theta^*)$ is the Fisher information matrix (Section 4.7.2). A proof can be found e.g., in [Ric95, p275].

It can be shown that the MLE achieves the Cramér-Rao lower bound, and hence has the smallest asymptotic variance of any unbiased estimator. Thus MLE is said to be asymptotically optimal.
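
As a small numerical illustration, the sketch below compares the two unbiased estimators of a Gaussian mean discussed above (the first data point and the sample mean) with the Cramér-Rao bound. It relies on the standard fact that for a Gaussian with known variance $\sigma^2$ the Fisher information for the mean is $1/\sigma^2$, so the bound is $\sigma^2/N$; the function name and settings are hypothetical.

```python
import random


def estimator_variances(n, theta=1.0, sigma=1.0, num_datasets=20_000, seed=0):
    """Empirical variance of two unbiased estimators, plus the CRLB."""
    rng = random.Random(seed)
    first_point, sample_mean = [], []
    for _ in range(num_datasets):
        xs = [rng.gauss(theta, sigma) for _ in range(n)]
        first_point.append(xs[0])
        sample_mean.append(sum(xs) / n)

    def variance(values):
        center = sum(values) / len(values)
        return sum((v - center) ** 2 for v in values) / len(values)

    crlb = sigma**2 / n  # 1 / (N * F(theta*)) with F = 1 / sigma^2
    return variance(first_point), variance(sample_mean), crlb


if __name__ == "__main__":
    print(estimator_variances(n=10))
```

The first-point estimator has variance near $\sigma^2$ regardless of $N$, whereas the sample mean sits at the bound $\sigma^2/N$.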

4.7.6.3 The bias-variance tradeoff


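In this section, we discuss a fundamental tradeoff that needs to be made when picking a method for parameter estimation, assuming our goal is to minimize the mean squared error (MSE) of our estimate. Let $\hat{\theta} = \hat{\theta}(\mathcal{D})$ denote the estimate, and $\bar{\theta} = \mathbb{E}[\hat{\theta}]$ denote the expected value of the estimate (as we vary $\mathcal{D}$). (All expectations and variances are wrt $p(\mathcal{D}|\theta^*)$, but we drop the explicit conditioning for notational brevity.) Then we have

$$\mathbb{E}\left[(\hat{\theta} - \theta^*)^2\right] = \mathbb{E}\left[\left[(\hat{\theta} - \bar{\theta}) + (\bar{\theta} - \theta^*)\right]^2\right] \tag{4.238}$$

$$= \mathbb{E}\left[(\hat{\theta} - \bar{\theta})^2\right] + 2(\bar{\theta} - \theta^*)\,\mathbb{E}\left[\hat{\theta} - \bar{\theta}\right] + (\bar{\theta} - \theta^*)^2 \tag{4.239}$$

$$= \mathbb{E}\left[(\hat{\theta} - \bar{\theta})^2\right] + (\bar{\theta} - \theta^*)^2 \tag{4.240}$$

$$= \mathbb{V}[\hat{\theta}] + \text{bias}^2(\hat{\theta}) \tag{4.241}$$

where the cross term vanishes because $\mathbb{E}[\hat{\theta} - \bar{\theta}] = 0$. In words,

$$\text{MSE} = \text{variance} + \text{bias}^2 \tag{4.242}$$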


This is called the bias-variance tradeoff (see e.g., [GBD92]). What it means is that it might be 
wise to use a biased estimator, so long as it reduces our variance by more than the square of the bias, 
assuming our goal is to minimize squared error. 


4.7.6.4 Example: MAP estimator for a Gaussian mean 


Let us give an example, based on [Hof09, p79]. Suppose we want to estimate the mean of a Gaussian from $\boldsymbol{x} = (x_1, \ldots, x_N)$. We assume the data is sampled from $x_n \sim \mathcal{N}(\theta^* = 1, \sigma^2)$. An obvious estimate is the MLE. This has a bias of 0 and a variance of

$$\mathbb{V}[\bar{x}|\theta^*] = \frac{\sigma^2}{N} \tag{4.243}$$

But we could also use a MAP estimate. In Section 4.6.4.2, we show that the MAP estimate under a Gaussian prior of the form $\mathcal{N}(\theta_0, \sigma^2/\kappa_0)$ is given by

$$\tilde{x} \triangleq \frac{N}{N + \kappa_0}\bar{x} + \frac{\kappa_0}{N + \kappa_0}\theta_0 = w\bar{x} + (1 - w)\theta_0 \tag{4.244}$$

where $0 \leq w \leq 1$ controls how much we trust the MLE compared to our prior. The bias and variance are given by

$$\mathbb{E}[\tilde{x}] - \theta^* = w\theta^* + (1 - w)\theta_0 - \theta^* = (1 - w)(\theta_0 - \theta^*) \tag{4.245}$$

$$\mathbb{V}[\tilde{x}] = w^2\frac{\sigma^2}{N} \tag{4.246}$$
[Figure 4.24 panel titles: (a) “sampling distribution, truth = 1.0, prior = 0.0, n = 5”; (b) “MSE of postmean / MSE of MLE”.]

Figure 4.24: Left: Sampling distribution of the MAP estimate (equivalent to the posterior mean) under a $\mathcal{N}(\theta_0 = 0, \sigma^2/\kappa_0)$ prior with different prior strengths $\kappa_0$. (If we set $\kappa_0 = 0$, the MAP estimate reduces to the MLE.) The data is $n = 5$ samples drawn from $\mathcal{N}(\theta^* = 1, \sigma^2 = 1)$. Right: MSE relative to that of the MLE versus sample size. Adapted from Figure 5.6 of [Hof09]. Generated by samplingDistributionGaussianShrinkage.ipynb.


So although the MAP estimate is biased (assuming $w < 1$), it has lower variance.

Let us assume that our prior is slightly misspecified, so we use $\theta_0 = 0$, whereas the truth is $\theta^* = 1$. In Figure 4.24(a), we see that the sampling distribution of the MAP estimate for $\kappa_0 > 0$ is biased away from the truth, but has lower variance (is narrower) than that of the MLE.

In Figure 4.24(b), we plot $\text{mse}(\tilde{x})/\text{mse}(\bar{x})$ vs $N$. We see that the MAP estimate has lower MSE than the MLE for $\kappa_0 \in \{1, 2\}$. The case $\kappa_0 = 0$ corresponds to the MLE, and the case $\kappa_0 = 3$ corresponds to a strong prior, which hurts performance because the prior mean is wrong. Thus we see that, provided the prior strength is properly “tuned”, a MAP estimate can outperform an ML estimate in terms of minimizing MSE.
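
The MSE ratio plotted in Figure 4.24(b) follows directly from Equations (4.243), (4.245) and (4.246). The sketch below evaluates it in closed form; it is not the notebook's code, and the function name `mse_ratio` is an assumption made for illustration.

```python
def mse_ratio(n, kappa0, theta_star=1.0, theta0=0.0, sigma2=1.0):
    """MSE of the MAP estimate divided by MSE of the MLE (closed form)."""
    w = n / (n + kappa0)
    bias = (1 - w) * (theta0 - theta_star)  # Equation (4.245)
    variance = w**2 * sigma2 / n            # Equation (4.246)
    mse_map = variance + bias**2            # MSE = variance + bias^2
    mse_mle = sigma2 / n                    # the MLE is unbiased
    return mse_map / mse_mle


if __name__ == "__main__":
    for kappa0 in (0, 1, 2, 3):
        print(kappa0, [round(mse_ratio(n, kappa0), 3) for n in (5, 10, 20, 50)])
```

For $N = 5$ the ratio is below 1 for $\kappa_0 \in \{1, 2\}$ and above 1 for $\kappa_0 = 3$, matching the behavior described for the figure.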


4.7.6.5 Example: MAP estimator for linear regression 


Another important example of the bias-variance tradeoff arises in ridge regression, which we discuss in Section 11.3. In brief, this corresponds to MAP estimation for linear regression under a Gaussian prior, $p(\boldsymbol{w}) = \mathcal{N}(\boldsymbol{w}|\boldsymbol{0}, \lambda^{-1}\mathbf{I})$. The zero-mean prior encourages the weights to be small, which reduces overfitting; the precision term, $\lambda$, controls the strength of this prior. Setting $\lambda = 0$ results in the MLE; using $\lambda > 0$ results in a biased estimate. To illustrate the effect on the variance, consider a simple example where we fit a 1d ridge regression model using 2 different values of $\lambda$. Figure 4.25 on the left plots each individual fitted curve, and on the right plots the average fitted curve. We see that as we increase the strength of the regularizer, the variance decreases, but the bias increases. See also Figure 4.26 where we give a cartoon sketch of the bias variance tradeoff in terms of model complexity.
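
The same effect can be seen in a much simpler setting than the RBF model of Figure 4.25. The sketch below is a deliberately reduced stand-in: a scalar model $y = wx + \epsilon$ with no intercept, for which the ridge estimate has the closed form $\hat{w} = \sum_n x_n y_n / (\sum_n x_n^2 + \lambda)$. The design points, true weight, and function name are all illustrative assumptions.

```python
import random


def ridge_bias_variance(lam, w_true=2.0, noise_std=1.0, num_datasets=5000, seed=0):
    """Empirical bias and variance of a scalar ridge estimate."""
    rng = random.Random(seed)
    xs = [i / 10 for i in range(1, 11)]  # fixed 1d design (hypothetical)
    sxx = sum(x * x for x in xs)
    fits = []
    for _ in range(num_datasets):
        ys = [w_true * x + rng.gauss(0.0, noise_std) for x in xs]
        sxy = sum(x * y for x, y in zip(xs, ys))
        fits.append(sxy / (sxx + lam))   # lam = 0 recovers the MLE
    mean_fit = sum(fits) / len(fits)
    variance = sum((w - mean_fit) ** 2 for w in fits) / len(fits)
    return mean_fit - w_true, variance


if __name__ == "__main__":
    for lam in (0.0, 5.0):
        print(lam, ridge_bias_variance(lam))
```

Increasing `lam` shrinks the estimate towards zero, which lowers its variance across datasets at the price of a larger (negative) bias.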
Figure 4.25: Illustration of bias-variance tradeoff for ridge regression. We generate 100 data sets from the true function, shown in solid green. Left: we plot the regularized fit for 20 different data sets. We use linear regression with a Gaussian RBF expansion, with 25 centers evenly spread over the [0,1] interval. Right: we plot the average of the fits, averaged over all 100 datasets. Top row: strongly regularized: we see that the individual fits are similar to each other (low variance), but the average is far from the truth (high bias). Bottom row: lightly regularized: we see that the individual fits are quite different from each other (high variance), but the average is close to the truth (low bias). Adapted from [Bis06] Figure 3.5. Generated by biasVarModelComplexity3.ipynb.


[Figure 4.26 column labels: Low Variance, High Variance.]

Figure 4.26: Cartoon illustration of the bias variance tradeoff. From http://scott.fortmann-roe.com/docs/BiasVariance.html. Used with kind permission of Scott Fortmann-Roe.

4.7.6.6 Bias-variance tradeoff for classification


If we use 0-1 loss instead of squared error, the frequentist risk is no longer expressible as squared bias 
plus variance. In fact, one can show (Exercise 7.2 of [HTF09]) that the bias and variance combine 
multiplicatively. If the estimate is on the correct side of the decision boundary, then the bias is 
negative, and decreasing the variance will decrease the misclassification rate. But if the estimate 
is on the wrong side of the decision boundary, then the bias is positive, so it pays to increase the 
variance [Fri97a]. This little known fact illustrates that the bias-variance tradeoff is not very useful 
for classification. It is better to focus on expected loss, not directly on bias and variance. We can 
approximate the expected loss using cross validation, as we discuss in Section 4.5.5. 


4.8 Exercises 


Exercise 4.1 [MLE for the univariate Gaussian *] 


Show that the MLE for a univariate Gaussian is given by

$$\hat{\mu} = \frac{1}{N}\sum_{n=1}^N x_n \tag{4.247}$$

$$\hat{\sigma}^2 = \frac{1}{N}\sum_{n=1}^N (x_n - \hat{\mu})^2 \tag{4.248}$$


Exercise 4.2 [MAP estimation for 1D Gaussians *] 


(Source: Jaakkola.) 


Consider samples $x_1, \ldots, x_n$ from a Gaussian random variable with known variance $\sigma^2$ and unknown mean $\mu$. We further assume a prior distribution (also Gaussian) over the mean, $\mu \sim \mathcal{N}(m, s^2)$, with fixed mean $m$ and fixed variance $s^2$. Thus the only unknown is $\mu$.


a. Calculate the MAP estimate $\hat{\mu}_{MAP}$. You can state the result without proof. Alternatively, with a lot more work, you can compute derivatives of the log posterior, set to zero and solve.


b. Show that as the number of samples $n$ increases, the MAP estimate converges to the maximum likelihood estimate.


c. Suppose $n$ is small and fixed. What does the MAP estimator converge to if we increase the prior variance $s^2$?


d. Suppose $n$ is small and fixed. What does the MAP estimator converge to if we decrease the prior variance $s^2$?


Exercise 4.3 [Gaussian posterior credible interval] 


(Source: DeGroot.) Let $X \sim \mathcal{N}(\mu, \sigma^2 = 4)$ where $\mu$ is unknown but has prior $\mu \sim \mathcal{N}(\mu_0, \sigma_0^2 = 9)$. The posterior after seeing $n$ samples is $\mu \sim \mathcal{N}(\mu_n, \sigma_n^2)$. (This is called a credible interval, and is the Bayesian analog of a confidence interval.) How big does $n$ have to be to ensure

$$p(\ell \leq \mu_n \leq u|\mathcal{D}) \geq 0.95 \tag{4.249}$$

where $(\ell, u)$ is an interval (centered on $\mu_n$) of width 1 and $\mathcal{D}$ is the data? Hint: recall that 95% of the probability mass of a Gaussian is within $\pm 1.96\sigma$ of the mean.

Exercise 4.4 [BIC for Gaussians *]
(Source: Jaakkola.)


The Bayesian information criterion (BIC) is a penalized log-likelihood function that can be used for model selection. It is defined as

$$\text{BIC} = \log p(\mathcal{D}|\hat{\theta}_{ML}) - \frac{d}{2}\log(N) \tag{4.250}$$


where d is the number of free parameters in the model and N is the number of samples. In this question, 
we will see how to use this to choose between a full covariance Gaussian and a Gaussian with a diagonal 
covariance. Obviously a full covariance Gaussian has higher likelihood, but it may not be “worth” the extra 
parameters if the improvement over a diagonal covariance matrix is too small. So we use the BIC score to 
choose the model. 


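We can write

$$\log p(\mathcal{D}|\hat{\Sigma}, \hat{\mu}) = -\frac{N}{2}\text{tr}\left(\hat{\Sigma}^{-1}\hat{S}\right) - \frac{N}{2}\log(|\hat{\Sigma}|) \tag{4.251}$$

$$\hat{S} = \frac{1}{N}\sum_{i=1}^N (x_i - \bar{x})(x_i - \bar{x})^T \tag{4.252}$$

where $\hat{S}$ is the scatter matrix (empirical covariance), the trace of a matrix is the sum of its diagonals, and we have used the trace trick.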


a. Derive the BIC score for a Gaussian in D dimensions with full covariance matrix. Simplify your answer as 
much as possible, exploiting the form of the MLE. Be sure to specify the number of free parameters d. 


b. Derive the BIC score for a Gaussian in D dimensions with a diagonal covariance matrix. Be sure to specify the number of free parameters d. Hint: for the diagonal case, the ML estimate of $\Sigma$ is the same as $\hat{\Sigma}_{ML}$ except the off-diagonal terms are zero:

$$\hat{\Sigma}_{diag} = \text{diag}(\hat{\Sigma}_{ML}(1,1), \ldots, \hat{\Sigma}_{ML}(D,D)) \tag{4.253}$$


Exercise 4.5 [BIC for a 2d discrete distribution] 
(Source: Jaakkola.) 


Let $x \in \{0, 1\}$ denote the result of a coin toss ($x = 0$ for tails, $x = 1$ for heads). The coin is potentially biased, so that heads occurs with probability $\theta_1$. Suppose that someone else observes the coin flip and reports to you the outcome, $y$. But this person is unreliable and only reports the result correctly with probability $\theta_2$; i.e., $p(y|x, \theta_2)$ is given by

        | y = 0          y = 1
x = 0   | θ2             1 − θ2
x = 1   | 1 − θ2         θ2

Assume that $\theta_2$ is independent of $x$ and $\theta_1$.


a. Write down the joint probability distribution $p(x, y|\theta)$ as a 2 x 2 table, in terms of $\theta = (\theta_1, \theta_2)$.


b. Suppose we have the following dataset: $\boldsymbol{x} = (1,1,0,1,1,0,0)$, $\boldsymbol{y} = (1,0,0,0,1,0,1)$. What are the MLEs for $\theta_1$ and $\theta_2$? Justify your answer. Hint: note that the likelihood function factorizes,

$$p(x, y|\theta) = p(y|x, \theta_2)\,p(x|\theta_1) \tag{4.254}$$

What is $p(\mathcal{D}|\hat{\theta}, M_2)$ where $M_2$ denotes this 2-parameter model? (You may leave your answer in fractional form if you wish.)

c. Now consider a model with 4 parameters, $\theta = (\theta_{0,0}, \theta_{0,1}, \theta_{1,0}, \theta_{1,1})$, representing $p(x, y|\theta) = \theta_{x,y}$. (Only 3 of these parameters are free to vary, since they must sum to one.) What is the MLE of $\theta$? What is $p(\mathcal{D}|\hat{\theta}, M_4)$ where $M_4$ denotes this 4-parameter model?


d. Suppose we are not sure which model is correct. We compute the leave-one-out cross validated log likelihood of the 2-parameter model and the 4-parameter model as follows:

$$L(m) = \sum_{i=1}^n \log p(x_i, y_i|m, \hat{\theta}(\mathcal{D}_{-i})) \tag{4.255}$$

and $\hat{\theta}(\mathcal{D}_{-i})$ denotes the MLE computed on $\mathcal{D}$ excluding row $i$. Which model will CV pick and why? Hint: notice how the table of counts changes when you omit each training case one at a time.


e. Recall that an alternative to CV is to use the BIC score, defined as

$$\text{BIC}(M, \mathcal{D}) \triangleq \log p(\mathcal{D}|\hat{\theta}_{MLE}) - \frac{\text{dof}(M)}{2}\log N \tag{4.256}$$

where dof(M) is the number of free parameters in the model. Compute the BIC scores for both models (use log base e). Which model does BIC prefer?


Exercise 4.6 [A mixture of conjugate priors is conjugate *] 


Consider a mixture prior

$$p(\theta) = \sum_k p(z = k)\,p(\theta|z = k) \tag{4.257}$$

where each $p(\theta|z = k)$ is conjugate to the likelihood. Prove that this is a conjugate prior.


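Exercise 4.7 [ML estimator $\sigma^2_{mle}$ is biased]


Show that $\hat{\sigma}^2_{MLE} = \frac{1}{N}\sum_{n=1}^N (x_n - \hat{\mu})^2$ is a biased estimator of $\sigma^2$, i.e., show

$$\mathbb{E}_{X_1, \ldots, X_N \sim \mathcal{N}(\mu, \sigma)}\left[\hat{\sigma}^2(X_1, \ldots, X_N)\right] \neq \sigma^2$$

Hint: note that $X_1, \ldots, X_N$ are independent, and use the fact that the expectation of a product of independent random variables is the product of the expectations.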


Exercise 4.8 [Estimation of $\sigma^2$ when $\mu$ is known *]


Suppose we sample $x_1, \ldots, x_N \sim \mathcal{N}(\mu, \sigma^2)$ where $\mu$ is a known constant. Derive an expression for the MLE for $\sigma^2$ in this case. Is it unbiased?


Exercise 4.9 [Variance and MSE of estimators for Gaussian variance *] 


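Prove that the standard error for the MLE for a Gaussian variance is

$$\sqrt{\mathbb{V}[\sigma^2_{mle}]} = \frac{\sqrt{2(N-1)}}{N}\sigma^2 \tag{4.258}$$

Hint: use the fact that

$$\frac{N-1}{\sigma^2}\sigma^2_{unb} \sim \chi^2_{N-1} \tag{4.259}$$

and that $\mathbb{V}[\chi^2_{N-1}] = 2(N - 1)$. Finally, show that $\text{MSE}(\sigma^2_{unb}) = \frac{2}{N-1}\sigma^4$ and $\text{MSE}(\sigma^2_{mle}) = \frac{2N-1}{N^2}\sigma^4$.
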
5.1 Bayesian decision theory


Bayesian inference provides the optimal way to update our beliefs about hidden quantities H given 
observed data X = x by computing the posterior p(H|x). However, at the end of the day, we 
need to turn our beliefs into actions that we can perform in the world. How can we decide which 
action is best? This is where Bayesian decision theory comes in. In this chapter, we give a brief 
introduction. For more details, see e.g., [DeG70; KWW22]. 


5.1.1 Basics 


In decision theory, we assume the decision maker, or agent, has a set of possible actions, A, to 
choose from. For example, consider the case of a hypothetical doctor treating someone who may 
have COVID-19. Suppose the actions are to do nothing, or to give the patient an expensive drug 
with bad side effects, but which can save their life. 

Each of these actions has costs and benefits, which will depend on the underlying state of nature $H \in \mathcal{H}$. We can encode this information into a loss function $\ell(h, a)$, that specifies the loss we incur if we take action $a \in \mathcal{A}$ when the state of nature is $h \in \mathcal{H}$.

For example, suppose the state is defined by the age of the patient (young vs old), and whether 
they have COVID-19 or not. Note that the age can be observed directly, but the disease state must 
be inferred from noisy observations, as we discussed in Section 2.3. Thus the state is partially 
observed. 

Let us assume that the cost of administering a drug is the same, no matter what the state of the 
patient is. However, the benefits will differ. If the patient is young, we expect them to live a long 
time, so the cost of not giving the drug if they have COVID-19 is high; but if the patient is old, they 
have fewer years to live, so the cost of not giving the drug if they have COVID-19 is arguably less 
(especially in view of the side effects). In medical circles, a common unit of cost is quality-adjusted 
life years or QALY. Suppose that the expected QALY for a young person is 60, and for an old 
person is 10. Let us assume the drug costs the equivalent of 8 QALY, due to induced pain and 
suffering from side effects. Then we get the loss matrix shown in Table 5.1. 

These numbers reflect relative costs and benefits, and will depend on many factors. The numbers can be derived by asking the decision maker about their preferences about different possible outcomes. It is a theorem of decision theory that any consistent set of preferences can be converted into an ordinal cost scale (see e.g., https://en.wikipedia.org/wiki/Preference_(economics)).

State              | Nothing | Drugs
No COVID-19, young |    0    |   8
COVID-19, young    |   60    |   8
No COVID-19, old   |    0    |   8
COVID-19, old      |   10    |   8

Table 5.1: Hypothetical loss matrix for a decision maker, where there are 4 states of nature, and 2 possible actions.


test | age | pr(covid) | cost-noop | cost-drugs | action
  0  |  0  |   0.01    |   0.84    |    8.00    |   0
  0  |  1  |   0.01    |   0.14    |    8.00    |   0
  1  |  0  |   0.80    |  47.73    |    8.00    |   1
  1  |  1  |   0.80    |   7.95    |    8.00    |   0

Table 5.2: Optimal policy for treating COVID-19 patients for each possible observation.


Once we have specified the loss function, we can compute the posterior expected loss or risk for each possible action $a$ given all the relevant evidence, which may be a single datum $x$ or an entire data set $\mathcal{D}$, depending on the problem:

$$\rho(a|x) \triangleq \mathbb{E}_{p(h|x)}[\ell(h, a)] = \sum_{h \in \mathcal{H}} \ell(h, a)\,p(h|x) \tag{5.1}$$
hEH 


The optimal policy $\pi^*(x)$, also called the Bayes estimator or Bayes decision rule $\delta^*(x)$, specifies what action to take when presented with evidence $x$ so as to minimize the risk:

$$\pi^*(x) = \operatorname*{argmin}_{a \in \mathcal{A}} \mathbb{E}_{p(h|x)}[\ell(h, a)] \tag{5.2}$$

An alternative, but equivalent, way of stating this result is as follows. Let us define a utility function $U(h, a)$ to be the desirability of each possible action in each possible state. If we set $U(h, a) = -\ell(h, a)$, then the optimal policy is as follows:

$$\pi^*(x) = \operatorname*{argmax}_{a \in \mathcal{A}} \mathbb{E}_{h}[U(h, a)] \tag{5.3}$$

This is called the maximum expected utility principle.

Let us return to our COVID-19 example. The observation $x$ consists of the age (young or old)
and the test result (positive or negative). Using the results from Section 2.3.1 on Bayes rule for 
COVID-19 diagnosis, we can convert the test result into a distribution over disease states (i.e., 
compute the probability the patient has COVID-19 or not). Given this belief state, and the loss 
matrix in Table 5.1, we can compute the optimal policy for each possible observation, as shown in 
Table 5.2. 

We see from Table 5.2 that the drug should only be given to young people who test positive. If, 
however, we reduce the cost of the drug from 8 units to 5, then the optimal policy changes: in this 
case, we should give the drug to everyone who tests positive. The policy can also change depending 
on the reliability of the test. For example, if we increase the sensitivity from 0.875 to 0.975, then
the probability that someone has COVID-19 if they test positive increases from 0.80 to 0.81, which 
changes the optimal policy to be one in which we should administer the drug to everyone who tests 
positive, even if the drug costs 8 QALY. (See dtheory.ipynb for the code to reproduce this example.) 
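
The policy in Table 5.2 can be recomputed with a few lines of Python. This is a sketch and not the contents of dtheory.ipynb: the prior prevalence of 0.1 and the specificity of 0.975 are assumptions chosen because they reproduce the posterior probabilities and costs shown in the table (only the sensitivity of 0.875 is stated in this section), and the function names are hypothetical.

```python
QALY_LOSS = {"young": 60.0, "old": 10.0}  # loss of doing nothing if infected


def posterior_covid(test_positive, prior=0.1, sensitivity=0.875, specificity=0.975):
    """Bayes rule for p(COVID | test result); prior and specificity are assumed."""
    if test_positive:
        like_covid, like_healthy = sensitivity, 1 - specificity
    else:
        like_covid, like_healthy = 1 - sensitivity, specificity
    joint_covid = like_covid * prior
    joint_healthy = like_healthy * (1 - prior)
    return joint_covid / (joint_covid + joint_healthy)


def optimal_action(test_positive, age, drug_cost=8.0, **test_params):
    """Pick the action with the smaller posterior expected loss."""
    p_covid = posterior_covid(test_positive, **test_params)
    cost_nothing = p_covid * QALY_LOSS[age]  # loss is 0 if not infected
    cost_drug = drug_cost                    # same loss in every state
    action = "drug" if cost_drug < cost_nothing else "nothing"
    return p_covid, cost_nothing, cost_drug, action


if __name__ == "__main__":
    for test_positive in (False, True):
        for age in ("young", "old"):
            print(test_positive, age, optimal_action(test_positive, age))
```

Calling `optimal_action(True, "old", drug_cost=5.0)` or `optimal_action(True, "old", sensitivity=0.975)` shows the two policy changes described in the text.
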
So far, we have implicitly assumed that the agent is risk neutral. This means that their decision 
is not affected by the degree of certainty in a set of outcomes. For example, such an agent would be 
indifferent between getting $50 for sure, or a 50% chance of $100 or $0. By contrast, a risk averse 
agent would choose the first. We can generalize the framework of Bayesian decision theory to risk 
sensitive applications, but we do not pursue the matter here. (See e.g., [Cho+15] for details.) 


5.1.2 Classification problems 


In this section, we use Bayesian decision theory to decide the optimal class label to predict given an observed input $x \in \mathcal{X}$.


5.1.2.1 Zero-one loss 


Suppose the states of nature correspond to class labels, so $\mathcal{H} = \mathcal{Y} = \{1, \ldots, C\}$. Furthermore, suppose the actions also correspond to class labels, so $\mathcal{A} = \mathcal{Y}$. In this setting, a very commonly used loss function is the zero-one loss $\ell_{01}(y^*, \hat{y})$, defined as follows:

          | ŷ = 0   ŷ = 1
y* = 0    |   0       1
y* = 1    |   1       0          (5.4)

We can write this more concisely as follows:

$$\ell_{01}(y^*, \hat{y}) = \mathbb{I}(y^* \neq \hat{y}) \tag{5.5}$$

In this case, the posterior expected loss is

$$\rho(\hat{y}|x) = p(\hat{y} \neq y^*|x) = 1 - p(y^* = \hat{y}|x) \tag{5.6}$$

Hence the action that minimizes the expected loss is to choose the most probable label:

$$\pi(x) = \operatorname*{argmax}_{y \in \mathcal{Y}} p(y|x) \tag{5.7}$$


This corresponds to the mode of the posterior distribution, also known as the maximum a 
posteriori or MAP estimate. 


5.1.2.2 Cost-sensitive classification 


Consider a binary classification problem where the loss function $\ell(y^*, \hat{y})$ is as follows:

$$\begin{pmatrix} \ell_{00} & \ell_{01} \\ \ell_{10} & \ell_{11} \end{pmatrix} \tag{5.8}$$

Let $p_0 = p(y^* = 0|x)$ and $p_1 = 1 - p_0$. Thus we should choose label $\hat{y} = 0$ iff

$$\ell_{00}p_0 + \ell_{10}p_1 < \ell_{01}p_0 + \ell_{11}p_1 \tag{5.9}$$

If $\ell_{00} = \ell_{11} = 0$, this simplifies to

$$p_1 < \frac{\ell_{01}}{\ell_{01} + \ell_{10}} \tag{5.10}$$

Now suppose $\ell_{10} = c\,\ell_{01}$, so a false negative costs $c$ times more than a false positive. The decision rule further simplifies to the following: pick $a = 0$ iff $p_1 < 1/(1 + c)$. For example, if a false negative costs twice as much as a false positive, so $c = 2$, then we use a decision threshold of 1/3 before declaring a positive.
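
As a small sketch of Equations (5.9) and (5.10), the code below picks the label with the lower posterior expected loss and also computes the equivalent probability threshold. The names `decide` and `threshold`, and the convention `loss[i][j]` for the cost of predicting `j` when the truth is `i`, are illustrative choices.

```python
def decide(p1, loss):
    """Return the label (0 or 1) with the smaller posterior expected loss.

    loss[i][j] is the cost of predicting j when the true label is i.
    """
    p0 = 1 - p1
    risk_predict_0 = loss[0][0] * p0 + loss[1][0] * p1
    risk_predict_1 = loss[0][1] * p0 + loss[1][1] * p1
    return 0 if risk_predict_0 < risk_predict_1 else 1


def threshold(loss01, loss10):
    """Threshold on p1 from Equation (5.10), assuming zero loss when correct."""
    return loss01 / (loss01 + loss10)


if __name__ == "__main__":
    loss = [[0.0, 1.0], [2.0, 0.0]]  # a false negative costs twice as much
    print(threshold(1.0, 2.0))       # 1/3
    print(decide(0.3, loss), decide(0.4, loss))
```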


5.1.2.3 Classification with the “reject” option 


In some cases, we may be able to say “I don’t know” instead of returning an answer that we don’t really trust; this is called picking the reject option (see e.g., [BW08]). This is particularly important in domains such as medicine and finance where we may be risk averse.

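We can formalize the reject option as follows. Suppose the states of nature are $\mathcal{H} = \mathcal{Y} = \{1, \ldots, C\}$, and the actions are $\mathcal{A} = \mathcal{Y} \cup \{0\}$, where action 0 represents the reject action. Now define the following loss function:

$$\ell(y^*, a) = \begin{cases} 0 & \text{if } y^* = a \text{ and } a \in \{1, \ldots, C\} \\ \lambda_r & \text{if } a = 0 \\ \lambda_e & \text{otherwise} \end{cases} \tag{5.11}$$

where $\lambda_r$ is the cost of the reject action, and $\lambda_e$ is the cost of a classification error. Exercise 5.1 asks you to show that the optimal action is to pick the reject action if the most probable class has a probability below $\lambda^* = 1 - \frac{\lambda_r}{\lambda_e}$, otherwise you should just pick the most probable class. In other words, the optimal policy is as follows:

$$a^* = \begin{cases} y^* & \text{if } p^* > \lambda^* \\ \text{reject} & \text{otherwise} \end{cases} \tag{5.12}$$

where

$$y^* = \operatorname*{argmax}_{y \in \{1, \ldots, C\}} p(y|x) \tag{5.13}$$

$$p^* = p(y^*|x) = \max_{y \in \{1, \ldots, C\}} p(y|x) \tag{5.14}$$

$$\lambda^* = 1 - \frac{\lambda_r}{\lambda_e} \tag{5.15}$$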


See Figure 5.1 for an illustration. 
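
The policy in Equation (5.12) is short enough to write out directly. In this sketch, classes are numbered 1 to C and the value 0 stands for the reject action, following the text; the function name and argument names are assumptions.

```python
def classify_with_reject(probs, reject_cost, error_cost):
    """Return the most probable class (1..C), or 0 to reject.

    probs[k] is p(y = k + 1 | x). We reject unless the top probability
    exceeds lambda* = 1 - reject_cost / error_cost.
    """
    lambda_star = 1 - reject_cost / error_cost
    best = max(range(len(probs)), key=probs.__getitem__)
    if probs[best] > lambda_star:
        return best + 1
    return 0


if __name__ == "__main__":
    print(classify_with_reject([0.9, 0.1], reject_cost=0.2, error_cost=1.0))
    print(classify_with_reject([0.6, 0.4], reject_cost=0.2, error_cost=1.0))
```

The rule follows from comparing expected losses: predicting the top class costs $\lambda_e(1 - p^*)$ on average, while rejecting always costs $\lambda_r$.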

One interesting application of the reject option arises when playing the TV game show Jeopardy. In this game, contestants have to solve various word puzzles and answer a variety of trivia questions, but if they answer incorrectly, they lose money. In 2011, IBM unveiled a computer system called Watson which beat the top human Jeopardy champion. Watson uses a variety of interesting techniques [Fer+10], but the most pertinent one for our present discussion is that it contains a module that estimates how confident it is of its answer. The system only chooses to “buzz in” its answer if sufficiently confident it is correct.

[Figure 5.1 plot labels: $p(y = 1|x)$, $p(y = 2|x)$, threshold, reject region; vertical axis from 0.0 to 1.0.]

Figure 5.1: For some regions of input space, where the class posteriors are uncertain, we may prefer not to choose class 1 or 2; instead we may prefer the reject option. Adapted from Figure 1.26 of [Bis06].


          |   Estimate    |
          |   0      1    | Row sum
Truth  0  |   TN     FP   |   N
       1  |   FN     TP   |   P
Col. sum  |   N̂      P̂    |

Table 5.3: Class confusion matrix for a binary classification problem. TP is the number of true positives, FP is the number of false positives, TN is the number of true negatives, FN is the number of false negatives, $P$ is the true number of positives, $\hat{P}$ is the predicted number of positives, $N$ is the true number of negatives, $\hat{N}$ is the predicted number of negatives.

For some other methods and applications, see e.g., [Cor+16; GEY19]. 


5.1.3 ROC curves 


In Section 5.1.2.2, we showed that we can pick the optimal label in a binary classification problem by thresholding the probability using a value $\tau$, derived from the relative cost of a false positive and false negative. Instead of picking a single threshold, we can consider using a set of different thresholds, and comparing the resulting performance, as we discuss below.
          |                Estimate
          |   0                           1
Truth  0  |   TN/N = TNR = Spec           FP/N = FPR = Type I = Fallout
       1  |   FN/P = FNR = Miss = Type II TP/P = TPR = Sens = Recall

Table 5.4: Class confusion matrix for a binary classification problem normalized per row to get $p(\hat{y}|y)$. Abbreviations: TNR = true negative rate, Spec = specificity, FPR = false positive rate, FNR = false negative rate, Miss = miss rate, TPR = true positive rate, Sens = sensitivity. Note FNR=1-TPR and FPR=1-TNR.


          |                Estimate
          |   0               1
Truth  0  |   TN/N̂ = NPV      FP/P̂ = FDR
       1  |   FN/N̂ = FOR      TP/P̂ = Prec = PPV

Table 5.5: Class confusion matrix for a binary classification problem normalized per column to get $p(y|\hat{y})$. Abbreviations: NPV = negative predictive value, FDR = false discovery rate, FOR = false omission rate, PPV = positive predictive value, Prec = precision. Note that FOR=1-NPV and FDR=1-PPV.


5.1.3.1 Class confusion matrices 


For any fixed threshold $\tau$, we consider the following decision rule:

$$\hat{y}_\tau(x) = \mathbb{I}(p(y = 1|x) \geq 1 - \tau) \tag{5.16}$$

We can compute the empirical number of false positives (FP) that arise from using this policy on a set of $N$ labeled examples as follows:

$$\text{FP}_\tau = \sum_{n=1}^N \mathbb{I}(\hat{y}_\tau(x_n) = 1, y_n = 0) \tag{5.17}$$

Similarly, we can compute the empirical number of false negatives (FN), true positives (TP), and true negatives (TN). We can store these results in a 2 x 2 class confusion matrix $C$, where $C_{ij}$ is the number of times an item with true class label $i$ was (mis)classified as having label $j$. In the case of binary classification problems, the resulting matrix will look like Table 5.3.

From this table, we can compute $p(\hat{y}|y)$ or $p(y|\hat{y})$, depending on whether we normalize across the rows or columns. We can derive various summary statistics from these distributions, as summarized in Table 5.4 and Table 5.5. For example, the true positive rate (TPR), also known as the sensitivity, recall or hit rate, is defined as

$$\text{TPR}_\tau = p(\hat{y} = 1|y = 1, \tau) = \frac{\text{TP}_\tau}{\text{TP}_\tau + \text{FN}_\tau} \tag{5.18}$$

and the false positive rate (FPR), also called the false alarm rate, or the type I error rate, is defined as

$$\text{FPR}_\tau = p(\hat{y} = 1|y = 0, \tau) = \frac{\text{FP}_\tau}{\text{FP}_\tau + \text{TN}_\tau} \tag{5.19}$$
[Figure 5.2 panels: (a) ROC curve, vertical axis TPR; (b) precision-recall curve, vertical axis precision.]

Figure 5.2: (a) ROC curves for two hypothetical classification systems. The red curve for system A is better than the blue curve for system B. We plot the true positive rate (TPR) vs the false positive rate (FPR) as we vary the threshold $\tau$. We also indicate the equal error rate (EER) with the red and blue dots, and the area under the curve (AUC) for classifier B by the shaded area. Generated by roc_plot.ipynb. (b) A precision-recall curve for two hypothetical classification systems. The red curve for system A is better than the blue curve for system B. Generated by pr_plot.ipynb.


We can now plot the TPR vs FPR as an implicit function of $\tau$. This is called a receiver operating characteristic or ROC curve. See Figure 5.2(a) for an example.


5.1.3.2 Summarizing ROC curves as a scalar 


The quality of a ROC curve is often summarized as a single number using the area under the 
curve or AUC. Higher AUC scores are better; the maximum is obviously 1. Another summary 
statistic that is used is the equal error rate or EER, also called the cross-over rate, defined as 
the value which satisfies FPR = FNR. Since FNR=1-TPR, we can compute the EER by drawing a 
line from the top left to the bottom right and seeing where it intersects the ROC curve (see points A 
and B in Figure 5.2(a)). Lower EER scores are better; the minimum is obviously 0 (corresponding to 
the top left corner). 
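
The ROC construction and the AUC summary can be sketched as follows. This is not the code in roc_plot.ipynb: it sweeps a threshold `t` on the score directly (so `t` plays the role of $1 - \tau$ in Equation (5.16)), uses the trapezoid rule for the area, and assumes the data contains at least one positive and one negative example. All function names are hypothetical.

```python
def rates(scores, labels, t):
    """TPR and FPR of the rule y_hat = I(score >= t)."""
    preds = [int(s >= t) for s in scores]
    tp = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
    fp = sum(p == 1 and y == 0 for p, y in zip(preds, labels))
    num_pos = sum(labels)
    num_neg = len(labels) - num_pos
    return tp / num_pos, fp / num_neg


def roc_points(scores, labels):
    """(FPR, TPR) pairs, from predicting nothing positive to everything."""
    points = [(0.0, 0.0)]
    for t in sorted(set(scores), reverse=True):
        tpr, fpr = rates(scores, labels, t)
        points.append((fpr, tpr))
    return points


def auc(points):
    """Area under a piecewise-linear curve via the trapezoid rule."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area


if __name__ == "__main__":
    scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.2]
    labels = [1, 1, 0, 1, 0, 0]
    print(auc(roc_points(scores, labels)))  # 8/9: 8 of 9 pos/neg pairs ordered correctly
```

The AUC here equals the fraction of (positive, negative) pairs in which the positive example receives the higher score, which is a convenient way to sanity-check the result by hand.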


5.1.3.3 Class imbalance 


In some problems, there is severe class imbalance. For example, in information retrieval, the set of negatives (irrelevant items) is usually much larger than the set of positives (relevant items). The ROC curve is unaffected by class imbalance, as the TPR and FPR are fractions within the positives and negatives, respectively. However, the usefulness of an ROC curve may be reduced in such cases: because the FPR has FP+TN in its denominator, a large change in the absolute number of false positives will not change the false positive rate very much when the number of negatives is large (see e.g., [SR15] for discussion). Thus all the “action” happens in the extreme left part of the curve. In such cases, we may choose to use other ways of summarizing the class confusion matrix, such as precision-recall curves, which we discuss in Section 5.1.4.
5.1.4 Precision-recall curves


In some problems, the notion of a “negative” is not well-defined. For example, consider detecting 
objects in images: if the detector works by classifying patches, then the number of patches examined 
— and hence the number of true negatives — is a parameter of the algorithm, not part of the problem 
definition. Similarly, information retrieval systems usually get to choose the initial set of candidate 
items, which are then ranked for relevance; by specifying a cutoff, we can partition this into a positive 
and negative set, but note that the size of the negative set depends on the total number of items 
retrieved, which is an algorithm parameter, not part of the problem specification. 

In these kinds of situations, we may choose to use a precision-recall curve to summarize the 
performance of our system, as we explain below. (See [DG06] for a more detailed discussion of the 
connection between ROC curves and PR curves.) 


5.1.4.1 Computing precision and recall 


The key idea is to replace the FPR with a quantity that is computed just from positives, namely the 
precision: 

$$P(\tau) \triangleq p(y = 1 \mid \hat{y} = 1, \tau) = \frac{TP_\tau}{TP_\tau + FP_\tau} \tag{5.20}$$


The precision measures what fraction of our detections are actually positive. We can compare this to 
the recall (which is the same as the TPR), which measures what fraction of the positives we actually 
detected: 


$$R(\tau) \triangleq p(\hat{y} = 1 \mid y = 1, \tau) = \frac{TP_\tau}{TP_\tau + FN_\tau} \tag{5.21}$$


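If $\hat{y}_n \in \{0,1\}$ is the predicted label, and $y_n \in \{0,1\}$ is the true label, we can estimate precision
and recall using

$$P(\tau) = \frac{\sum_n y_n \hat{y}_n}{\sum_n \hat{y}_n} \tag{5.22}$$
$$R(\tau) = \frac{\sum_n y_n \hat{y}_n}{\sum_n y_n} \tag{5.23}$$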


We can now plot the precision vs recall as we vary the threshold $\tau$. See Figure 5.2(b). Hugging
the top right is the best one can do. 


5.1.4.2 Summarizing PR curves as a scalar 


The PR curve can be summarized as a single number in several ways. First, we can quote the 
precision for a fixed recall level, such as the precision of the first K = 10 entities recalled. This 
is called the precision at K score. Alternatively, we can compute the area under the PR curve. 
However, it is possible that the precision does not drop monotonically with recall. For example, 
suppose a classifier has 90% precision at 10% recall, and 96% precision at 20% recall. In this case, 
rather than measuring the precision at a recall of 10%, we should measure the maximum precision 
we can achieve with at least a recall of 10% (which would be 96%). This is called the interpolated 
precision. The average of the interpolated precisions is called the average precision; it is equal
to the area under the interpolated PR curve, but may not be equal to the area under the raw PR
curve.¹ The mean average precision or mAP is the mean of the AP over a set of different PR
curves. 
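
The difference between the raw and interpolated PR curves can be sketched as follows. This is a minimal example on a made-up ranking; the helper names are hypothetical, and the step-wise area used here is only one of several conventions for average precision found in practice.

```python
def pr_points(scores, labels):
    """Precision and recall after each item of the ranking (descending score)."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    total_pos = sum(labels)
    tp = 0
    precision, recall = [], []
    for k, i in enumerate(order, start=1):
        tp += labels[i]
        precision.append(tp / k)
        recall.append(tp / total_pos)
    return precision, recall


def interpolate(precision):
    """Interpolated precision: the best precision achievable at this recall or higher."""
    out = list(precision)
    for i in range(len(out) - 2, -1, -1):
        out[i] = max(out[i], out[i + 1])
    return out


def area_under_pr(precision, recall):
    """Step-wise area: sum of precision times the increment in recall."""
    area, prev_recall = 0.0, 0.0
    for p, r in zip(precision, recall):
        area += p * (r - prev_recall)
        prev_recall = r
    return area


# Hypothetical ranked list: the top-scoring item happens to be a negative.
scores = [0.9, 0.8, 0.7, 0.6, 0.5]
labels = [0, 1, 1, 0, 1]

precision, recall = pr_points(scores, labels)
interp_precision = interpolate(precision)
raw_ap = area_under_pr(precision, recall)
interp_ap = area_under_pr(interp_precision, recall)
print(f"raw area = {raw_ap:.4f}, interpolated area (AP) = {interp_ap:.4f}")
```

Because interpolation replaces each precision value by a running maximum, the interpolated curve is non-increasing in recall and its area is never smaller than the raw area.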


5.1.4.3 F-scores 


For a fixed threshold, corresponding to a single point on the PR curve, we can compute a single 
precision and recall value, which we will denote by P and R. These are often combined into a single 
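statistic called the $F_\beta$ score, defined as follows:²

$$\frac{1}{F_\beta} = \frac{1}{1+\beta^2}\frac{1}{P} + \frac{\beta^2}{1+\beta^2}\frac{1}{R} \tag{5.24}$$

or equivalently

$$F_\beta \triangleq (1+\beta^2)\frac{P \cdot R}{\beta^2 P + R} = \frac{(1+\beta^2)\,TP}{(1+\beta^2)\,TP + \beta^2\,FN + FP} \tag{5.25}$$

If we set $\beta = 1$, we get the harmonic mean of precision and recall:

$$\frac{1}{F_1} = \frac{1}{2}\left(\frac{1}{P} + \frac{1}{R}\right) \tag{5.26}$$
$$F_1 = \frac{2}{1/R + 1/P} = 2\frac{P \cdot R}{P + R} = \frac{TP}{TP + \frac{1}{2}(FP + FN)} \tag{5.27}$$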


To understand why we use the harmonic mean instead of the arithmetic mean, (P + R)/2, consider 
the following scenario. Suppose we recall all entries, so $\hat{y}_n = 1$ for all $n$, and $R = 1$. In this case, the
precision $P$ will be given by the prevalence, $p(y = 1) = \frac{\sum_n \mathbb{I}(y_n = 1)}{N}$. Suppose the prevalence is low,
say $p(y = 1) = 10^{-4}$. The arithmetic mean of $P$ and $R$ is given by $(P + R)/2 = (10^{-4} + 1)/2 \approx 50\%$.
By contrast, the harmonic mean of this strategy is only $\frac{2 \times 10^{-4} \times 1}{1 + 10^{-4}} \approx 0.02\%$. In general, the harmonic
mean is more conservative, and requires both precision and recall to be high. 

Using the $F_1$ score weights precision and recall equally. However, if recall is more important, we may
use $\beta = 2$, and if precision is more important, we may use $\beta = 0.5$.
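
The two forms of the $F_\beta$ score in Equation (5.25), and the "recall everything" scenario above, can be checked numerically. The sketch below uses hypothetical function names and made-up counts.

```python
def f_beta_from_counts(tp, fp, fn, beta=1.0):
    """F_beta from confusion-matrix counts (right-hand side of Eq. 5.25)."""
    b2 = beta ** 2
    return (1 + b2) * tp / ((1 + b2) * tp + b2 * fn + fp)


def f_beta_from_pr(precision, recall, beta=1.0):
    """F_beta from precision and recall (left-hand form of Eq. 5.25)."""
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


# Made-up counts: 30 true positives, 10 false positives, 20 false negatives.
tp, fp, fn = 30, 10, 20
precision, recall = tp / (tp + fp), tp / (tp + fn)
for beta in (0.5, 1.0, 2.0):
    print(beta, f_beta_from_counts(tp, fp, fn, beta), f_beta_from_pr(precision, recall, beta))

# "Recall everything" when the prevalence is 1e-4: P = prevalence, R = 1.
prevalence = 1e-4
arithmetic_mean = (prevalence + 1.0) / 2
harmonic_mean = f_beta_from_pr(prevalence, 1.0, beta=1.0)
print(f"arithmetic = {arithmetic_mean:.4%}, harmonic (F1) = {harmonic_mean:.4%}")
```

In this toy case recall (0.6) is lower than precision (0.75), so the recall-weighted score with $\beta = 2$ comes out lower than the precision-weighted score with $\beta = 0.5$.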


5.1.4.4 Class imbalance 


ROC curves are insensitive to class imbalance, but PR curves are not, as noted in [Wil20]. To see this, 
let the fraction of positives in the dataset be $\pi = P/(P+N)$, and define the ratio $r = P/N = \pi/(1-\pi)$.
Let n = P + N be the population size. ROC curves are not affected by changes in r, since the TPR 
is defined as a ratio within the positive examples, and FPR is defined as a ratio within the negative 
examples. This means it does not matter which class we define as positive, and which we define as 
negative. 


1. For details, see https://sanchom.wordpress.com/tag/average-precision/.
2. We follow the notation from https://en.wikipedia.org/wiki/F-score#F/CE/B2.


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




Now consider PR curves. The precision can be written as 


$$\text{Prec} = \frac{TP}{TP + FP} = \frac{P \cdot TPR}{P \cdot TPR + N \cdot FPR} = \frac{TPR}{TPR + \frac{1}{r} FPR} \tag{5.28}$$


Thus Prec $\to 1$ as $\pi \to 1$ and $r \to \infty$, and Prec $\to 0$ as $\pi \to 0$ and $r \to 0$. For example, if we
change from a balanced problem where $\pi = 0.5$ to an imbalanced problem where $\pi = 0.1$ (so positives
are rarer), the precision at each threshold will drop, and the recall (aka TPR) will stay the same, 
so the overall PR curve will be lower. Thus if we have multiple binary problems with different 
prevalences (e.g., object detection of common or rare objects), we should be careful when averaging 
their precisions [HCD12]. 
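The F-score is also affected by class imbalance. To see this, note that we can rewrite the F-score
as follows:

$$\frac{1}{F_\beta} = \frac{1}{1+\beta^2}\frac{1}{P} + \frac{\beta^2}{1+\beta^2}\frac{1}{R} \tag{5.29}$$
$$= \frac{1}{1+\beta^2}\frac{TPR + \frac{1}{r}FPR}{TPR} + \frac{\beta^2}{1+\beta^2}\frac{1}{TPR} \tag{5.30}$$
$$F_\beta = \frac{(1+\beta^2)\,TPR}{TPR + \frac{1}{r}FPR + \beta^2} \tag{5.31}$$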
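
To see the effect of prevalence numerically, the sketch below holds one ROC operating point fixed (a made-up TPR and FPR) and evaluates Equations (5.28) and (5.31) for several values of $\pi$. The function names are assumptions introduced for illustration.

```python
def precision_from_rates(tpr, fpr, r):
    """Precision as a function of TPR, FPR and the class ratio r = P/N (Eq. 5.28)."""
    return tpr / (tpr + fpr / r)


def f_beta_from_rates(tpr, fpr, r, beta=1.0):
    """F_beta as a function of TPR, FPR and the class ratio r = P/N (Eq. 5.31)."""
    b2 = beta ** 2
    return (1 + b2) * tpr / (tpr + fpr / r + b2)


# A fixed (hypothetical) operating point on the ROC curve.
tpr, fpr = 0.8, 0.1

results = {}
for pi in (0.5, 0.1, 0.01):
    r = pi / (1 - pi)  # ratio of positives to negatives
    results[pi] = (precision_from_rates(tpr, fpr, r), f_beta_from_rates(tpr, fpr, r))
    print(f"pi = {pi:<5} precision = {results[pi][0]:.4f}  F1 = {results[pi][1]:.4f}")
```

The TPR and FPR never change in this loop, yet both precision and $F_1$ fall as positives become rarer, which is the sense in which PR-based metrics are sensitive to class imbalance.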


5.1.5 Regression problems 


So far, we have considered the case where there are a finite number of actions $\mathcal{A}$ and states of nature
$\mathcal{H}$. In this section, we consider the case where the set of actions and states are both equal to the real
line, $\mathcal{A} = \mathcal{H} = \mathbb{R}$. We will specify various commonly used loss functions for this case (which can be
extended to $\mathbb{R}^D$ by computing the loss elementwise). The resulting decision rules can be used to
compute the optimal parameters for an estimator to return, or the optimal action for a robot to take, 
etc. 


5.1.5.1 L2 loss 


The most common loss for continuous states and actions is the $\ell_2$ loss, also called squared error
or quadratic loss, which is defined as follows:


$$\ell_2(h, a) = (h - a)^2 \tag{5.32}$$


In this case, the risk is given by 


$$\rho(a|x) = \mathbb{E}\left[(h - a)^2 \mid x\right] = \mathbb{E}\left[h^2 \mid x\right] - 2a\,\mathbb{E}\left[h \mid x\right] + a^2 \tag{5.33}$$


The optimal action must satisfy the condition that the derivative of the risk (at that point) is zero 
(as explained in Chapter 8). Hence the optimal action is to pick the posterior mean: 


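$$\frac{\partial}{\partial a}\rho(a|x) = -2\,\mathbb{E}\left[h \mid x\right] + 2a = 0 \;\Rightarrow\; \pi(x) = \mathbb{E}\left[h \mid x\right] = \int h\, p(h|x)\, dh \tag{5.34}$$
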
This is often called the minimum mean squared error estimate or MMSE estimate. 
Figure 5.3: Illustration of $\ell_2$, $\ell_1$, and Huber loss functions with $\delta = 1.5$. Generated by huberLossPlot.ipynb.


5.1.5.2 L1 loss 


The $\ell_2$ loss penalizes deviations from the truth quadratically, and thus is sensitive to outliers. A
more robust alternative is the absolute or $\ell_1$ loss


$$\ell_1(h, a) = |h - a| \tag{5.35}$$


This is sketched in Figure 5.3. Exercise 5.4 asks you to show that the optimal estimate is the 
posterior median, i.e., a value $a$ such that $\Pr(h < a|x) = \Pr(h \geq a|x) = 0.5$. We can use this for
robust regression as discussed in Section 11.6.1. 


5.1.5.3 Huber loss 


Another robust loss function is the Huber loss [Hub64], defined as follows: 


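$$\ell_\delta(h, a) = \begin{cases} r^2/2 & \text{if } |r| \leq \delta \\ \delta|r| - \delta^2/2 & \text{if } |r| > \delta \end{cases} \tag{5.36}$$

where $r = h - a$. This is equivalent to $\ell_2$ for errors that are smaller than $\delta$, and is equivalent to $\ell_1$
for larger errors. See Figure 5.3 for a plot. We can use this for robust regression as discussed in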
Section 11.6.3. 
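
The three losses, and the claim that the $\ell_2$ and $\ell_1$ risks are minimized by the mean and median respectively, can be checked with a small grid search. In this sketch a handful of made-up numbers (including one outlier) stand in for samples from the posterior $p(h|x)$; the function names are hypothetical.

```python
import statistics


def l2_loss(h, a):
    return (h - a) ** 2


def l1_loss(h, a):
    return abs(h - a)


def huber_loss(h, a, delta=1.5):
    r = abs(h - a)
    # Quadratic near zero, linear in the tails; the two branches agree at |r| = delta.
    return r ** 2 / 2 if r <= delta else delta * r - delta ** 2 / 2


def empirical_risk(loss, samples, a):
    """Average loss of action a over samples standing in for the posterior p(h|x)."""
    return sum(loss(h, a) for h in samples) / len(samples)


# Hypothetical posterior samples with one large outlier.
samples = [0.0, 1.0, 2.0, 3.0, 100.0]
grid = [i / 100 for i in range(0, 10001)]  # candidate actions in [0, 100]

best_l2 = min(grid, key=lambda a: empirical_risk(l2_loss, samples, a))
best_l1 = min(grid, key=lambda a: empirical_risk(l1_loss, samples, a))
best_huber = min(grid, key=lambda a: empirical_risk(huber_loss, samples, a))

print("l2 minimizer:", best_l2, "(mean =", statistics.mean(samples), ")")
print("l1 minimizer:", best_l1, "(median =", statistics.median(samples), ")")
print("Huber minimizer:", best_huber)
```

The outlier drags the $\ell_2$ minimizer far from the bulk of the samples, while the $\ell_1$ and Huber minimizers stay close to them.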


5.1.6 Probabilistic prediction problems 


In Section 5.1.2, we assumed the set of possible actions was to pick a single class label (or possibly the 
“reject” or “do not know” action). In Section 5.1.5, we assumed the set of possible actions was to pick 
a real valued scalar. In this section, we assume the set of possible actions is to pick a probability 
distribution over some value of interest. That is, we want to perform probabilistic prediction 
or probabilistic forecasting, rather than predicting a specific value. More precisely, we assume 
the true “state of nature” is a distribution, h = p(Y |x), the action is another distribution, a = q(Y |x), 
and we want to pick $q$ to minimize $\mathbb{E}\left[\ell(p, q)\right]$ for a given $x$. We discuss various possible loss functions
below. 
5.1.6.1 KL, cross-entropy and log-loss


A common form of loss function for comparing two distributions is the Kullback Leibler divergence,
or KL divergence, which is defined as follows:


$$D_{\mathbb{KL}}(p \,\|\, q) \triangleq \sum_{y \in \mathcal{Y}} p(y) \log \frac{p(y)}{q(y)} \tag{5.37}$$


(We have assumed the variable y is discrete, for notational simplicity, but this can be generalized 
to real-valued variables.) In Section 6.2, we show that the KL divergence satisfies the following 
properties: $D_{\mathbb{KL}}(p \,\|\, q) \geq 0$ with equality iff $p = q$. Note that it is an asymmetric function of its
arguments. 

We can expand the KL as follows: 


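$$D_{\mathbb{KL}}(p \,\|\, q) = \sum_{y \in \mathcal{Y}} p(y) \log p(y) - \sum_{y \in \mathcal{Y}} p(y) \log q(y) \tag{5.38}$$
$$= -\mathbb{H}(p) + \mathbb{H}_{ce}(p, q) \tag{5.39}$$
$$\mathbb{H}(p) \triangleq -\sum_{y} p(y) \log p(y) \tag{5.40}$$
$$\mathbb{H}_{ce}(p, q) \triangleq -\sum_{y} p(y) \log q(y) \tag{5.41}$$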


The H(p) term is known as the entropy. This is a measure of uncertainty or variance of p; it is 
maximal if p is uniform, and is 0 if p is a degenerate or deterministic delta function. Entropy is often 
used in the field of information theory, which is concerned with optimal ways of compressing and 
communicating data (see Chapter 6). The optimal coding scheme will allocate fewer bits to more 
frequent symbols (i.e., values of Y for which p(y) is large), and more bits to less frequent symbols. A 
key result states that the number of bits needed to compress a dataset generated by a distribution p 
is at least H(p); the entropy therefore provides a lower bound on the degree to which we can compress 
data without losing information. The $\mathbb{H}_{ce}(p, q)$ term is known as the cross-entropy. This measures
the expected number of bits we need to use to compress a dataset coming from distribution p if we 
design our code using distribution q. Thus the KL is the extra number of bits we need to use to 
compress the data due to using the incorrect distribution q. If the KL is zero, it means that we can 
correctly predict the probabilities of all possible future events, and thus we have learned to predict 
the future as well as an “oracle” that has access to the true distribution p. 

To find the optimal distribution to use when predicting future data, we can minimize $D_{\mathbb{KL}}(p \,\|\, q)$.
Since $\mathbb{H}(p)$ is a constant wrt $q$, it can be ignored, and thus we can equivalently minimize the
cross-entropy:


$$q^*(Y|x) = \operatorname*{argmin}_{q} \mathbb{H}_{ce}\left(p(Y|x), q(Y|x)\right) \tag{5.42}$$


Now consider the special case in which the true state of nature is a degenerate distribution, which 
puts all its mass on a single outcome, say $c$, i.e., $h = p(Y|x) = \mathbb{I}(Y = c)$. This is often called
a “one-hot” distribution, since it turns “on” the $c$’th element of the vector, and leaves the other
elements “off”, as shown in Figure 2.1. In this case, the cross entropy becomes


$$\mathbb{H}_{ce}(\delta(Y = c), q) = -\sum_{y \in \mathcal{Y}} \delta(y = c) \log q(y) = -\log q(c) \tag{5.43}$$


This is known as the log loss of the predictive distribution q when given target label c. 
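
The relationships between entropy, cross-entropy, KL divergence and log loss can be verified on a small discrete example. This is a minimal sketch with made-up distributions and hypothetical function names; it uses natural logarithms, so the quantities are in nats rather than bits.

```python
import math


def entropy(p):
    """H(p) = -sum_y p(y) log p(y), with the convention 0 log 0 = 0."""
    return -sum(py * math.log(py) for py in p if py > 0)


def cross_entropy(p, q):
    """H_ce(p, q) = -sum_y p(y) log q(y)."""
    return -sum(py * math.log(qy) for py, qy in zip(p, q) if py > 0)


def kl_divergence(p, q):
    """D_KL(p || q) = sum_y p(y) log (p(y) / q(y))."""
    return sum(py * math.log(py / qy) for py, qy in zip(p, q) if py > 0)


# Made-up "true" distribution p and predictive distribution q over 3 outcomes.
p = [0.5, 0.25, 0.25]
q = [0.4, 0.4, 0.2]

kl_pq = kl_divergence(p, q)
kl_qp = kl_divergence(q, p)
print("KL(p||q) =", kl_pq, " KL(q||p) =", kl_qp)
print("H_ce(p,q) - H(p) =", cross_entropy(p, q) - entropy(p))

# One-hot target on class c = 1: the cross-entropy reduces to the log loss -log q(c).
one_hot = [0.0, 1.0, 0.0]
log_loss = cross_entropy(one_hot, q)
print("log loss =", log_loss)
```

The two KL values differ, which illustrates the asymmetry noted above, and the cross-entropy minus the entropy reproduces the KL as in Equation (5.39).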


5.1.6.2 Proper scoring rules 


Cross-entropy loss is a very common choice for probabilistic forecasting, but is not the only possible 
metric. The key property we desire is that the loss function is minimized iff the decision maker picks 
the distribution $q$ that matches the true distribution $p$, i.e., $\ell(p, p) \leq \ell(p, q)$, with equality iff $p = q$.
Such a loss function $\ell$ is called a proper scoring rule [GR07].

We can show that cross-entropy loss is a proper scoring rule by virtue of the fact that $D_{\mathbb{KL}}(p \,\|\, p) \leq
D_{\mathbb{KL}}(p \,\|\, q)$. However, the $\log p(y)/q(y)$ term can be quite sensitive to errors for low probability
events [QC+06]. A common alternative is to use the Brier score [Bri50], which is defined as follows
(for a discrete distribution with $C$ values):


$$\ell(p, q) \triangleq \frac{1}{C} \sum_{c=1}^{C} \left(q(y = c|x) - p(y = c|x)\right)^2 \tag{5.44}$$


This is just the squared error of the predictive distribution compared to the true distribution, when 
viewed as vectors. Since it is based on squared error, the Brier score is less sensitive to extremely 
rare or extremely common classes. Fortunately, it is also a proper scoring rule. 
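
One way to see the "proper" property empirically is to look at the expected score when labels are drawn from $p$ and scored against a forecast $q$. The sketch below does this for a binary problem with a made-up $p$, using a grid search over forecasts; the function names are hypothetical, and the expectation over one-hot labels is an assumption about how the score would be used in practice.

```python
import math


def brier(p, q):
    """Brier score (Eq. 5.44): mean squared difference between the two probability vectors."""
    return sum((qc - pc) ** 2 for pc, qc in zip(p, q)) / len(p)


def log_loss(p, q):
    """Cross-entropy of forecast q against target p (with 0 log 0 = 0)."""
    return -sum(pc * math.log(qc) for pc, qc in zip(p, q) if pc > 0)


def expected_score(score, p, q):
    """Expected score when the label is sampled from p and encoded as a one-hot vector."""
    total = 0.0
    for c, pc in enumerate(p):
        one_hot = [1.0 if k == c else 0.0 for k in range(len(p))]
        total += pc * score(one_hot, q)
    return total


p_true = [0.3, 0.7]                          # hypothetical true distribution
candidates = [i / 100 for i in range(1, 100)]  # forecast probability of class 1

best_brier = min(candidates, key=lambda t: expected_score(brier, p_true, [1 - t, t]))
best_log = min(candidates, key=lambda t: expected_score(log_loss, p_true, [1 - t, t]))
print("forecast minimizing expected Brier score:", best_brier)
print("forecast minimizing expected log loss:  ", best_log)
```

Both scores are minimized by the honest forecast $q = p$, even though they penalize deviations from it differently.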


5.2 Choosing the “right” model 


In this section, we consider the setting in which we have several candidate (parametric) models (e.g., 
neural networks with different numbers of layers), and we want to choose the “right” one. This can 
be tackled using tools from Bayesian decision theory. 


5.2.1 Bayesian hypothesis testing 


Suppose we have two hypotheses or models, commonly called the null hypothesis, $M_0$, and the
alternative hypothesis, $M_1$, and we want to know which one is more likely to be true. This is
called hypothesis testing.

If we use 0-1 loss, the optimal decision is to pick the alternative hypothesis iff $p(M_1|\mathcal{D}) > p(M_0|\mathcal{D})$,
or equivalently, if $p(M_1|\mathcal{D})/p(M_0|\mathcal{D}) > 1$. If we use a uniform prior, $p(M_0) = p(M_1) = 0.5$, the
decision rule becomes: select $M_1$ iff $p(\mathcal{D}|M_1)/p(\mathcal{D}|M_0) > 1$. This quantity, which is the ratio of
marginal likelihoods of the two models, is known as the Bayes factor:


$$B_{1,0} \triangleq \frac{p(\mathcal{D}|M_1)}{p(\mathcal{D}|M_0)} \tag{5.45}$$


This is like a likelihood ratio, except we integrate out the parameters, which allows us to compare 
models of different complexity, due to the Bayesian Occam’s razor effect explained in Section 5.2.3. 
Bayes factor BF(1,0)     Interpretation
BF < 1/100               Decisive evidence for M0
BF < 1/10                Strong evidence for M0
1/10 < BF < 1/3          Moderate evidence for M0
1/3 < BF < 1             Weak evidence for M0
1 < BF < 3               Weak evidence for M1
3 < BF < 10              Moderate evidence for M1
BF > 10                  Strong evidence for M1
BF > 100                 Decisive evidence for M1


Table 5.6: Jeffreys scale of evidence for interpreting Bayes factors. 


If $B_{1,0} > 1$ then we prefer model 1, otherwise we prefer model 0. Of course, it might be that $B_{1,0}$
is only slightly greater than 1. In that case, we are not very confident that model 1 is better. Jeffreys 
[Jef61] proposed a scale of evidence for interpreting the magnitude of a Bayes factor, which is shown 
in Table 5.6. This is a Bayesian alternative to the frequentist concept of a p-value (see Section 5.5.3). 

We give a worked example of how to compute Bayes factors in Section 5.2.1.1. 


5.2.1.1 Example: Testing if a coin is fair 


As an example, suppose we observe some coin tosses, and want to decide if the data was generated by 
a fair coin, $\theta = 0.5$, or a potentially biased coin, where $\theta$ could be any value in $[0,1]$. Let us denote
the first model by $M_0$ and the second model by $M_1$. The marginal likelihood under $M_0$ is simply


$$p(\mathcal{D}|M_0) = \left(\frac{1}{2}\right)^N \tag{5.46}$$


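where $N$ is the number of coin tosses. From Equation (4.143), the marginal likelihood under $M_1$,
using a Beta prior, is


$$p(\mathcal{D}|M_1) = \int p(\mathcal{D}|\theta)\, p(\theta)\, d\theta = \frac{B(\alpha_1 + N_1, \alpha_0 + N_0)}{B(\alpha_1, \alpha_0)} \tag{5.47}$$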


We plot $\log p(\mathcal{D}|M_1)$ vs the number of heads $N_1$ in Figure 5.4(a), assuming $N = 5$ and a uniform
prior, $\alpha_1 = \alpha_0 = 1$. (The shape of the curve is not very sensitive to $\alpha_1$ and $\alpha_0$, as long as the
prior is symmetric, so $\alpha_0 = \alpha_1$.) If we observe 2 or 3 heads, the unbiased coin hypothesis $M_0$
is more likely than $M_1$, since $M_0$ is a simpler model (it has no free parameters); it would be
a suspicious coincidence if the coin were biased but happened to produce almost exactly 50/50
heads/tails. However, as the counts become more extreme, we favor the biased coin hypothesis. Note
that, if we plot the log Bayes factor, $\log B_{1,0}$, it will have exactly the same shape, since $\log p(\mathcal{D}|M_0)$
is a constant. 
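
The coin example can be reproduced in a few lines. The sketch below evaluates Equations (5.46) and (5.47) for $N = 5$ tosses under a uniform Beta(1, 1) prior and prints the Bayes factor for each possible number of heads. The function names are introduced here for illustration and are not taken from the notebook behind Figure 5.4.

```python
import math


def log_beta(a, b):
    """Log of the beta function B(a, b), computed via log-gamma for stability."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def log_marglik_fair(n_heads, n_tails):
    """log p(D | M0): a fair coin has no free parameters (Eq. 5.46)."""
    return (n_heads + n_tails) * math.log(0.5)


def log_marglik_biased(n_heads, n_tails, alpha1=1.0, alpha0=1.0):
    """log p(D | M1): Bernoulli likelihood integrated over a Beta prior (Eq. 5.47)."""
    return log_beta(alpha1 + n_heads, alpha0 + n_tails) - log_beta(alpha1, alpha0)


N = 5
log_bf = {}
for n_heads in range(N + 1):
    n_tails = N - n_heads
    log_bf[n_heads] = log_marglik_biased(n_heads, n_tails) - log_marglik_fair(n_heads, n_tails)
    print(f"N1 = {n_heads}: Bayes factor B_10 = {math.exp(log_bf[n_heads]):.3f}")
```

With this prior the Bayes factor is below 1 for 2 or 3 heads (favoring the fair coin) and well above 1 for 0 or 5 heads, matching the qualitative description above.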


5.2.2 Bayesian model selection 


Now suppose we have a set M of more than 2 models, and we want to pick the most likely. This 
is called model selection. We can view this as a decision theory problem, where the action space 
Figure 5.4: (a) Log marginal likelihood vs number of heads for the coin tossing example. (b) BIC approximation.
(The vertical scale is arbitrary, since we are holding N fixed.) Generated by coins_model_sel_demo.ipynb.


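requires choosing one model, $m \in \mathcal{M}$. If we have a 0-1 loss, the optimal action is to pick the most
probable model:


$$\hat{m} = \operatorname*{argmax}_{m \in \mathcal{M}} p(m|\mathcal{D}) \tag{5.48}$$

where

$$p(m|\mathcal{D}) = \frac{p(\mathcal{D}|m)\, p(m)}{\sum_{m \in \mathcal{M}} p(\mathcal{D}|m)\, p(m)} \tag{5.49}$$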


is the posterior over models. If the prior over models is uniform, p(m) = 1/|M|, then the MAP 
model is given by 


$$\hat{m} = \operatorname*{argmax}_{m \in \mathcal{M}} p(\mathcal{D}|m) \tag{5.50}$$

The quantity $p(\mathcal{D}|m)$ is given by

$$p(\mathcal{D}|m) = \int p(\mathcal{D}|\theta, m)\, p(\theta|m)\, d\theta \tag{5.51}$$

This is known as the marginal likelihood, or the evidence for model $m$. Intuitively, it is the
likelihood of the data averaged over all possible parameter values, weighted by the prior $p(\theta|m)$. If
all settings of $\theta$ assign high probability to the data, then this is probably a good model.


5.2.2.1 Example: polynomial regression 


As an example of Bayesian model selection, we will consider polynomial regression in 1d. Figure 5.5 
shows the posterior over three different models, corresponding to polynomials of degrees 1, 2 and 3 fit 
[Figure 5.5 panels: (a) deg=1, logev=-16.28; (b) deg=2, logev=-20.64; (c) deg=3, logev=-24.95; (d) posterior over models.]

Figure 5.5: Illustration of Bayesian model selection for polynomial regression. (a-c) We fit polynomials of
degrees 1, 2 and 3 to $N = 5$ data points. The solid green curve is the true function, the dashed red curve
is the prediction (dotted blue lines represent $\pm 2\sigma$ around the mean). (d) We plot the posterior over models,
$p(m|\mathcal{D})$, assuming a uniform prior $p(m) \propto 1$. Generated by linreg_eb_modelsel_vs_n.ipynb.


to N = 5 data points. We use a uniform prior over models, and use empirical Bayes to estimate the 
prior over the regression weights (see Section 11.7.7). We then compute the evidence for each model 
(see Section 11.7 for details on how to do this). We see that there is not enough data to justify a 
complex model, so the MAP model is m = 1. Figure 5.6 shows the analogous plot for N = 30 data 
points. Now we see that the MAP model is m = 2; the larger sample size means we can safely pick a 
more complex model. 


5.2.3 Occam’s razor 


Consider two models, a simple one, $m_1$, and a more complex one, $m_2$. Suppose that both can explain
the data by suitably optimizing their parameters, i.e., for which $p(\mathcal{D}|\hat{\theta}_1, m_1)$ and $p(\mathcal{D}|\hat{\theta}_2, m_2)$ are
both large. Intuitively we should prefer $m_1$, since it is simpler and just as good as $m_2$. This principle
is known as Occam’s razor. 

Let us now see how ranking models based on their marginal likelihood, which involves averaging 
the likelihood wrt the prior, will give rise to this behavior. The complex model will put less prior 
probability on the “good” parameters that explain the data, $\hat{\theta}_2$, since the prior must integrate to
1.0 over the entire parameter space. Thus it will take averages in parts of parameter space with 
low likelihood. By contrast, the simpler model has fewer parameters, so the prior is concentrated 
[Figure 5.6 panels: (a) deg=1, logev=-146.15; (b) deg=2, logev=-72.37; (c) deg=3, logev=-77.22; (d) posterior over models.]


Figure 5.6: Same as Figure 5.5 except now $N = 30$. Generated by linreg_eb_modelsel_vs_n.ipynb.


over a smaller volume; thus its averages will mostly be in the good part of parameter space, near $\hat{\theta}_1$.
Hence we see that the marginal likelihood will prefer the simpler model. This is called the Bayesian
Occam’s razor effect [Mac95; MG05].

Another way to understand the Bayesian Occam’s razor effect is to compare the relative predictive 
abilities of simple and complex models. Since probabilities must sum to one, we have $\sum_{\mathcal{D}'} p(\mathcal{D}'|m) = 1$,
where the sum is over all possible datasets. Complex models, which can predict many things, must 
spread their predicted probability mass thinly, and hence will not obtain as large a probability for 
any given data set as simpler models. This is sometimes called the conservation of probability 
mass principle, and is illustrated in Figure 5.7. On the horizontal axis we plot all possible data sets 
in order of increasing complexity (measured in some abstract sense). On the vertical axis we plot the 
predictions of 3 possible models: a simple one, M1; a medium one, M2; and a complex one, M3. We 
also indicate the actually observed data $\mathcal{D}_0$ by a vertical line. Model 1 is too simple and assigns low
probability to $\mathcal{D}_0$. Model 3 also assigns $\mathcal{D}_0$ relatively low probability, because it can predict many
data sets, and hence it spreads its probability quite widely and thinly. Model 2 is “just right”: it 
predicts the observed data with a reasonable degree of confidence, but does not predict too many 
other things. Hence model 2 is the most probable model. 


Figure 5.7: A schematic illustration of the Bayesian Occam’s razor. The broad (green) curve corresponds to a 
complex model, the narrow (blue) curve to a simple model, and the middle (red) curve is just right. Adapted 
from Figure 3.13 of [Bis06]. See also [MG05, Figure 2] for a similar plot produced on real data. 


5.2.4 Connection between cross validation and marginal likelihood 


We have seen how the marginal likelihood helps us choose models of the “right” complexity. In 
non-Bayesian approaches to model selection, it is standard to use cross validation (Section 4.5.5) for 
this purpose. 

It turns out that the marginal likelihood is closely related to the leave-one-out cross-validation 
(LOO-CV) estimate, as we now show. We start with the marginal likelihood for model m, which we 
write in sequential form as follows: 


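$$p(\mathcal{D}|m) = \prod_{n=1}^{N} p(y_n | y_{1:n-1}, x_{1:N}, m) = \prod_{n=1}^{N} p(y_n | x_n, \mathcal{D}_{1:n-1}, m) \tag{5.52}$$

where

$$p(y|x, \mathcal{D}_{1:n-1}, m) = \int p(y|x, \theta)\, p(\theta|\mathcal{D}_{1:n-1}, m)\, d\theta \tag{5.53}$$

Suppose we use a plugin approximation to the above distribution to get

$$p(y|x, \mathcal{D}_{1:n-1}, m) \approx \int p(y|x, \theta)\, \delta(\theta - \hat{\theta}_m(\mathcal{D}_{1:n-1}))\, d\theta = p(y|x, \hat{\theta}_m(\mathcal{D}_{1:n-1})) \tag{5.54}$$

Then we get

$$\log p(\mathcal{D}|m) \approx \sum_{n=1}^{N} \log p(y_n|x_n, \hat{\theta}_m(\mathcal{D}_{1:n-1})) \tag{5.55}$$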


This is similar to a leave-one-out cross-validation estimate of the likelihood, which has the form
$\frac{1}{N}\sum_{n=1}^{N} \log p(y_n|x_n, \hat{\theta}_m(\mathcal{D}_{1:n-1,n+1:N}))$, except we ignore the $\mathcal{D}_{n+1:N}$ part. The intuition behind
the connection is this: an overly complex model will overfit the “early” examples and will then predict 
the remaining ones poorly, and thus will also get a low cross-validation score. See [FH20] for a more 
detailed discussion of the connection between these performance metrics. 
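
The sequential factorization in Equation (5.52) can be checked exactly for a model where everything is available in closed form. The sketch below uses a beta-Bernoulli model (an assumption made here for convenience, with no inputs $x_n$), for which the one-step-ahead posterior predictive is $(\alpha_1 + N_1)/(\alpha_1 + \alpha_0 + n - 1)$. It also computes a leave-one-out analogue for comparison. The data and function names are made up.

```python
import math


def log_beta(a, b):
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def log_marglik_closed_form(ys, alpha1=1.0, alpha0=1.0):
    """log p(D | m) for a beta-Bernoulli model, computed in one shot."""
    n1 = sum(ys)
    n0 = len(ys) - n1
    return log_beta(alpha1 + n1, alpha0 + n0) - log_beta(alpha1, alpha0)


def log_marglik_sequential(ys, alpha1=1.0, alpha0=1.0):
    """Sum of log one-step-ahead predictive probabilities, as in Eq. (5.52)."""
    total, n1, n0 = 0.0, 0, 0
    for y in ys:
        p_heads = (alpha1 + n1) / (alpha1 + alpha0 + n1 + n0)  # uses D_{1:n-1} only
        total += math.log(p_heads if y == 1 else 1.0 - p_heads)
        n1 += y
        n0 += 1 - y
    return total


def loo_log_predictive(ys, alpha1=1.0, alpha0=1.0):
    """Leave-one-out analogue: predict each y_n from all the other observations."""
    n, n1 = len(ys), sum(ys)
    total = 0.0
    for y in ys:
        p_heads = (alpha1 + n1 - y) / (alpha1 + alpha0 + n - 1)
        total += math.log(p_heads if y == 1 else 1.0 - p_heads)
    return total / n


ys = [1, 0, 1, 1, 0, 1, 1, 1]  # hypothetical coin tosses
closed = log_marglik_closed_form(ys)
sequential = log_marglik_sequential(ys)
print("closed form:", closed, " sequential:", sequential)
print("LOO average log predictive:", loo_log_predictive(ys))
```

The closed-form and sequential values agree to numerical precision; the LOO quantity differs because each prediction also conditions on the "future" observations.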
5.2.5 Information criteria


The marginal likelihood, $p(\mathcal{D}|m) = \int p(\mathcal{D}|\theta, m)\, p(\theta)\, d\theta$, which is needed for Bayesian model selection
discussed in Section 5.2.2, can be difficult to compute, since it requires marginalizing over the entire 
parameter space. Furthermore, the result can be quite sensitive to the choice of prior. In this section, 
we discuss some other related metrics for model selection known as information criteria. We only 
give a brief discussion; see e.g., [GHV14] for further details. 


5.2.5.1 The Bayesian information criterion (BIC) 


The Bayesian information criterion or BIC [Sch78] can be thought of as a simple approximation 
to the log marginal likelihood. In particular, if we make a Gaussian approximation to the posterior, 
as discussed in Section 4.6.8.2, we get (from Equation (4.215)) the following:


$$\log p(\mathcal{D}|m) \approx \log p(\mathcal{D}|\hat{\theta}_{\text{map}}) + \log p(\hat{\theta}_{\text{map}}) - \frac{1}{2}\log|\mathbf{H}| \tag{5.56}$$


where $\mathbf{H}$ is the Hessian of the negative log joint, $-\log p(\mathcal{D}, \theta)$, evaluated at the MAP estimate $\hat{\theta}_{\text{map}}$.
We see that Equation (5.56) is the log likelihood plus some penalty terms. If we have a uniform prior,
$p(\theta) \propto 1$, we can drop the prior term, and replace the MAP estimate with the MLE, $\hat{\theta}$, yielding


$$\log p(\mathcal{D}|m) \approx \log p(\mathcal{D}|\hat{\theta}) - \frac{1}{2}\log|\mathbf{H}| \tag{5.57}$$


We now focus on approximating the $\log|\mathbf{H}|$ term, which is sometimes called the Occam factor,
since it is a measure of model complexity (volume of the posterior distribution). We have $\mathbf{H} =
\sum_{i=1}^{N} \mathbf{H}_i$, where $\mathbf{H}_i = \nabla\nabla \log p(\mathcal{D}_i|\theta)$. Let us approximate each $\mathbf{H}_i$ by a fixed matrix $\hat{\mathbf{H}}$. Then we
have


$$\log|\mathbf{H}| = \log|N\hat{\mathbf{H}}| = \log(N^D|\hat{\mathbf{H}}|) = D\log N + \log|\hat{\mathbf{H}}| \tag{5.58}$$


where $D = \dim(\theta)$ and we have assumed $\mathbf{H}$ is full rank. We can drop the $\log|\hat{\mathbf{H}}|$ term, since it is
independent of $N$, and thus will get overwhelmed by the likelihood. Putting all the pieces together,
we get the BIC score that we want to maximize:


$$J_{\text{BIC}}(m) = \log p(\mathcal{D}|m) \approx \log p(\mathcal{D}|\hat{\theta}, m) - \frac{D_m}{2}\log N \tag{5.59}$$

We can also define the BIC loss, that we want to minimize, by multiplying by -2:

$$L_{\text{BIC}}(m) = -2\log p(\mathcal{D}|\hat{\theta}, m) + D_m \log N \tag{5.60}$$

(The use of 2 as a scale factor is chosen to simplify the expression when using a model with a Gaussian 
likelihood.) 


5.2.5.2 Akaike information criterion


The Akaike information criterion [Aka74] is closely related to the BIC. It has the form

$$L_{\text{AIC}}(m) = -2\log p(\mathcal{D}|\hat{\theta}, m) + 2D \tag{5.61}$$

This penalizes complex models less heavily than BIC whenever $\log N > 2$ (i.e., $N \geq 8$), since the regularization term is independent of
$N$. This estimator can be derived from a frequentist perspective.
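
As a small worked example of Equations (5.60) and (5.61), the sketch below scores the two coin models from Section 5.2.1.1: $M_0$ fixes $\theta = 0.5$ (so it has $D = 0$ free parameters) and $M_1$ fits $\theta$ by maximum likelihood ($D = 1$). The counts and function names are made up for illustration.

```python
import math


def bernoulli_loglik(n_heads, n_tails, theta):
    """Log likelihood of the counts under a Bernoulli(theta) model (0 log 0 = 0)."""
    ll = 0.0
    if n_heads:
        ll += n_heads * math.log(theta)
    if n_tails:
        ll += n_tails * math.log(1.0 - theta)
    return ll


def bic_loss(loglik, num_params, num_data):
    """L_BIC = -2 log p(D | theta_hat, m) + D_m log N  (Eq. 5.60)."""
    return -2.0 * loglik + num_params * math.log(num_data)


def aic_loss(loglik, num_params):
    """L_AIC = -2 log p(D | theta_hat, m) + 2 D  (Eq. 5.61)."""
    return -2.0 * loglik + 2.0 * num_params


def score_coin_models(n_heads, n_tails):
    n = n_heads + n_tails
    ll0 = bernoulli_loglik(n_heads, n_tails, 0.5)          # M0: no free parameters
    ll1 = bernoulli_loglik(n_heads, n_tails, n_heads / n)  # M1: theta set to its MLE
    return {
        "bic": (bic_loss(ll0, 0, n), bic_loss(ll1, 1, n)),
        "aic": (aic_loss(ll0, 0), aic_loss(ll1, 1)),
    }


extreme = score_coin_models(9, 1)   # 9 heads out of 10
balanced = score_coin_models(5, 5)  # 5 heads out of 10
print("9H/1T:", extreme)
print("5H/5T:", balanced)
```

Lower is better for both criteria. With 9 heads out of 10 the improved fit of $M_1$ outweighs its penalty, while with 5 heads out of 10 the two models fit equally well and the penalty makes $M_0$ the preferred choice.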
5.2.5.3 Minimum description length (MDL)


We can think about the problem of scoring different models in terms of information theory (Chapter 6). 
The goal is for the sender to communicate the data to the receiver. First the sender needs to specify 
which model m to use; this takes C(m) = — log p(m) bits (see Section 6.1). Then the receiver can 
fit the model, by computing $\hat{\theta}_m$, and can thus approximately reconstruct the data. To perfectly
reconstruct the data, the sender needs to send the residual errors that cannot be explained by the
model; this takes $-L(m) = -\log p(\mathcal{D}|\hat{\theta}, m) = -\sum_n \log p(y_n|\hat{\theta}, m)$ bits. The total cost is


$$L_{\text{MDL}}(m) = -\log p(\mathcal{D}|\hat{\theta}, m) + C(m) \tag{5.62}$$


We see that this has the same basic form as BIC/AIC. Choosing the model which minimizes this 
objective is known as the minimum description length or MDL principle. See e.g., [HY01b] for 
details. 


5.2.6 Posterior inference over effect sizes and Bayesian significance testing 


The approach to hypothesis testing discussed in Section 5.2.1 relies on computing the Bayes factors 
for the null vs the alternative model, $p(\mathcal{D}|H_0)/p(\mathcal{D}|H_1)$. Unfortunately, computing the necessary
marginal likelihoods can be computationally difficult, and the results can be sensitive to the choice of
prior. Furthermore, we are often more interested in estimating an effect size, which is the difference
in magnitude between two parameters, rather than in deciding if an effect size is 0 (null hypothesis)
or not (alternative hypothesis); the latter is called a point null hypothesis, and is often regarded
as an irrelevant “straw man” (see e.g., [Mak+19] and references therein).

For example, suppose we have two classifiers, $m_1$ and $m_2$, and we want to know which one is
better. That is, we want to perform a comparison of classifiers. Let $\mu_1$ and $\mu_2$ be their average
accuracies, and let $\Delta = \mu_1 - \mu_2$ be the difference in their accuracies. The probability that model 1 is
more accurate, on average, than model 2 is given by $p(\Delta > 0|\mathcal{D})$. However, even if this probability is
large, the improvement may not be practically significant. So it is better to compute a probability
such as $p(\Delta > \epsilon|\mathcal{D})$ or $p(|\Delta| > \epsilon|\mathcal{D})$, where $\epsilon$ represents the minimal magnitude of effect size that is
meaningful for the problem at hand. These are called a one-sided test and a two-sided test, respectively.

More generally, let $R = [-\epsilon, \epsilon]$ represent a region of practical equivalence or ROPE [Kru15;
KL17]. We can define 3 events of interest: the null hypothesis $H_0 : \Delta \in R$, which says both methods
are practically the same (which is a more realistic assumption than $H_0 : \Delta = 0$); $H_A : \Delta > \epsilon$, which
says $m_1$ is better than $m_2$; and $H_B : \Delta < -\epsilon$, which says $m_2$ is better than $m_1$. To choose amongst
these 3 hypotheses, we just have to compute $p(\Delta|\mathcal{D})$, which avoids the need to compute Bayes factors.
In the sections below, we discuss how to compute this quantity using two different kinds of model. 


5.2.6.1 Bayesian t-test for difference in means 


Suppose we have two classifiers, $m_1$ and $m_2$, which are evaluated on the same set of $N$ test examples.
Let $e_i^m$ be the error of method $m$ on test example $i$. (Or this could be the conditional log likelihood,
$e_i^m = \log p_m(y_i|x_i)$.) Since the classifiers are applied to the same data, we can use a paired test for
comparing them, which is more sensitive than looking at average performance, since the factors that
make one example easy or hard to classify (e.g., due to label noise) will be shared by both methods.
Thus we will work with the differences, $d_i = e_i^1 - e_i^2$. We assume $d_i \sim \mathcal{N}(\Delta, \sigma^2)$. We are interested
in $p(\Delta|d)$, where $d = (d_1, \ldots, d_N)$.

If we use an uninformative prior for the unknown parameters $(\Delta, \sigma)$, one can show that the
posterior marginal for the mean is given by a Student distribution:


$$p(\Delta|d) = \mathcal{T}_{N-1}(\Delta \mid \mu, s^2/N)$$


where $\mu = \frac{1}{N}\sum_{i=1}^{N} d_i$ is the sample mean, and $s^2 = \frac{1}{N-1}\sum_{i=1}^{N}(d_i - \mu)^2$ is an unbiased estimate
of the variance. Hence we can easily compute $p(|\Delta| > \epsilon|d)$, with a ROPE of $\epsilon = 0.01$ (say). This
is known as a Bayesian t-test [Ben+17]. (See also [Rou+09] for Bayesian t-test based on Bayes 
factors, and [Die98] for a non-Bayesian approach to comparing classifiers.) 

An alternative to a formal test is to just plot the posterior $p(\Delta|d)$. If this distribution is tightly
centered on 0, we can conclude that there is no significant difference between the methods. (In fact,
an even simpler approach is to just make a boxplot of the data, $\{d_i\}$, which avoids the need for any
formal statistical analysis.) 

Note that this kind of problem arises in many applications, not just evaluating classifiers. For
example, suppose we have a set of $N$ people, each of whom is exposed to two drugs; let $e_i^m$ be the
outcome (e.g., sickness level) when person $i$ is exposed to drug $m$, and let $d_i = e_i^1 - e_i^2$ be the
difference in response. We can then analyse the effect of the drug by computing $p(\Delta|d)$ as we
discussed above. 
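
A minimal sketch of the Bayesian t-test described above: we draw samples from the Student posterior $\mathcal{T}_{N-1}(\Delta \mid \mu, s^2/N)$ and estimate the posterior mass in the three regions defined by the ROPE. The paired differences are made-up numbers, the function name is hypothetical, and Monte Carlo sampling is used here only to avoid depending on a library for the Student CDF.

```python
import math
import random
import statistics


def bayesian_t_test(d, eps, num_samples=100_000, seed=0):
    """Monte Carlo estimate of p(Delta > eps | d), p(|Delta| <= eps | d), p(Delta < -eps | d)."""
    n = len(d)
    mu = statistics.mean(d)
    s2 = statistics.variance(d)      # unbiased estimate (divides by n - 1)
    scale = math.sqrt(s2 / n)
    nu = n - 1                       # degrees of freedom of the Student posterior
    rng = random.Random(seed)
    above = inside = below = 0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)
        v = rng.gammavariate(nu / 2, 2.0)          # chi-squared with nu degrees of freedom
        delta = mu + scale * z / math.sqrt(v / nu)  # sample from the Student posterior
        if delta > eps:
            above += 1
        elif delta < -eps:
            below += 1
        else:
            inside += 1
    return above / num_samples, inside / num_samples, below / num_samples


# Hypothetical paired differences d_i = e_i^1 - e_i^2 on 10 test examples.
d = [0.02, 0.03, -0.01, 0.04, 0.02, 0.01, 0.03, 0.00, 0.02, 0.05]
p_above, p_rope, p_below = bayesian_t_test(d, eps=0.01)
print(f"p(Delta > eps) = {p_above:.3f}, p(in ROPE) = {p_rope:.3f}, p(Delta < -eps) = {p_below:.3f}")
```

How to read the sign of $\Delta$ depends on whether $e_i^m$ is an error (lower is better) or a log likelihood (higher is better).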


5.2.6.2 Bayesian χ²-test for difference in rates


Now suppose we have two classifiers which are evaluated on different test sets. Let $y_m$ be the
number of correct examples from method $m \in \{1, 2\}$ out of $N_m$ trials, so the accuracy rate is
$y_m/N_m$. We assume $y_m \sim \text{Bin}(N_m, \theta_m)$, so we are interested in $p(\Delta|\mathcal{D})$, where $\Delta = \theta_1 - \theta_2$, and
$\mathcal{D} = (y_1, N_1, y_2, N_2)$ is all the data.

If we use a uniform prior for $\theta_1$ and $\theta_2$ (i.e., $p(\theta_j) = \text{Beta}(\theta_j|1, 1)$), the posterior is given by


$$p(\theta_1, \theta_2|\mathcal{D}) = \text{Beta}(\theta_1|y_1 + 1, N_1 - y_1 + 1)\,\text{Beta}(\theta_2|y_2 + 1, N_2 - y_2 + 1)$$


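The posterior for $\Delta$ is given by

$$p(\Delta|\mathcal{D}) = \int_0^1 \int_0^1 \mathbb{I}(\Delta = \theta_1 - \theta_2)\, p(\theta_1|\mathcal{D}_1)\, p(\theta_2|\mathcal{D}_2)\, d\theta_1\, d\theta_2$$
$$= \int_0^1 \text{Beta}(\theta_1|y_1 + 1, N_1 - y_1 + 1)\,\text{Beta}(\theta_1 - \Delta|y_2 + 1, N_2 - y_2 + 1)\, d\theta_1$$

We can then evaluate this for any value of $\Delta$ that we choose. For example, we can compute

$$p(\Delta > \epsilon|\mathcal{D}) = \int_{\epsilon}^{\infty} p(\Delta|\mathcal{D})\, d\Delta \tag{5.63}$$

(We can compute this using 1 dimensional numerical integration or analytically [Coo05].) This is
called a Bayesian χ²-test.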

Note that this kind of problem arises in many applications, not just evaluating classifiers. For
example, suppose the two groups are different companies selling the same product on Amazon, and
$y_m$ is the number of positive reviews for merchant $m$. Or suppose the two groups correspond to men
and women, and $y_m$ is the number of people in group $m$ who are left handed, and $N_m - y_m$ is
the number who are right handed. We can represent the data as a 2 × 2 contingency table of
counts, as shown in Table 5.7.

          LH    RH
Male       9    43    N1 = 52
Female     4    44    N2 = 48
Totals    13    87    100

Table 5.7: A 2 × 2 contingency table from http://en.wikipedia.org/wiki/Contingency_table. The
MLEs for the left handedness rate in males and females are $\hat{\theta}_1 = 9/52 = 0.1731$ and $\hat{\theta}_2 = 4/48 = 0.0833$.

The MLEs for the left handedness rate in males and females are $\hat{\theta}_1 = 9/52 = 0.1731$ and
$\hat{\theta}_2 = 4/48 = 0.0833$. It seems that there is a difference, but the sample size is low, so we cannot be
sure. Hence we will represent our uncertainty by computing $p(\Delta|\mathcal{D})$, where $\Delta = \theta_1 - \theta_2$ and $\mathcal{D}$ is the
table of counts. We find $p(\theta_1 > \theta_2|\mathcal{D}) = \int_0^{\infty} p(\Delta|\mathcal{D})\, d\Delta = 0.901$, which suggests that left handedness is
more common in males, consistent with other studies [PP+20].
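
As a rough check on this number, the tail probability can be approximated by Monte Carlo rather than
by numerical integration: draw samples from the two independent Beta posteriors and count how often
the first exceeds the second. This is a minimal sketch, not the method used to produce the figure in the
text; the function name, sample count, and seed are illustrative choices.

```python
import random

def prob_theta1_greater(y1, n1, y2, n2, num_samples=100_000, seed=0):
    """Monte Carlo estimate of p(theta1 > theta2 | D) under uniform Beta(1, 1) priors."""
    rng = random.Random(seed)
    count = 0
    for _ in range(num_samples):
        # Posterior for each rate is Beta(y + 1, N - y + 1).
        theta1 = rng.betavariate(y1 + 1, n1 - y1 + 1)
        theta2 = rng.betavariate(y2 + 1, n2 - y2 + 1)
        count += theta1 > theta2
    return count / num_samples

# Counts from Table 5.7: 9 of 52 males and 4 of 48 females are left handed.
p_male_higher = prob_theta1_greater(y1=9, n1=52, y2=4, n2=48)
print(f"p(theta1 > theta2 | D) is approximately {p_male_higher:.3f}")
```

The estimate should land close to the value of about 0.90 quoted above, up to Monte Carlo error.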


5.3 Frequentist decision theory 


In this section, we discuss frequentist decision theory. In this approach, we treat the unknown
state of nature (often denoted by $\theta$ instead of $h$) as a fixed but unknown quantity, and we treat the
data $\mathbf{x}$ as random. Thus instead of conditioning on $\mathbf{x}$, we average over it, to compute the loss we
expect to incur if we apply our decision procedure (estimator) to many different datasets. We give
the details below.


5.3.1 Computing the risk of an estimator 


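We define the frequentist risk of an estimator $\delta$ given an unknown state of nature $\theta$ to be the expected
loss when applying that estimator to data $\mathbf{x}$, where the expectation is over the data, sampled from
$p(\mathbf{x}|\theta)$:

$$R(\theta, \delta) \triangleq \mathbb{E}_{p(\mathbf{x}|\theta)}\left[\ell(\theta, \delta(\mathbf{x}))\right] \tag{5.64}$$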


We give an example of this in Section 5.3.1.1. 


5.3.1.1 Example 


In this section, we consider the problem of estimating the mean of a Gaussian. We assume the
data is sampled from $x_n \sim \mathcal{N}(\theta^*, \sigma^2 = 1)$, and we let $\mathbf{x} = (x_1, \ldots, x_N)$. If we use quadratic loss,
$\ell_2(\theta, \hat{\theta}) = (\theta - \hat{\theta})^2$, the corresponding risk function is the MSE.

We now consider 5 different estimators for computing $\theta$:

• $\delta_1(\mathbf{x}) = \bar{x}$, the sample mean.
• $\delta_2(\mathbf{x}) = \mathrm{median}(\mathbf{x})$, the sample median.
• $\delta_3(\mathbf{x}) = \theta_0$, a fixed value.
• $\delta_\kappa(\mathbf{x})$, the posterior mean under a $\mathcal{N}(\theta|\theta_0, \sigma^2/\kappa)$ prior:

$$\delta_\kappa(\mathbf{x}) = \frac{N}{N+\kappa}\bar{x} + \frac{\kappa}{N+\kappa}\theta_0 = w\bar{x} + (1-w)\theta_0 \tag{5.65}$$

For $\delta_\kappa$, we use $\theta_0 = 0$, and consider a weak prior, $\kappa = 1$, and a stronger prior, $\kappa = 5$.

3. This example is based on the following blog post by Bob Carpenter: https://bit.ly/2FykD1C.

Figure 5.8: Risk functions for estimating the mean of a Gaussian. Each curve represents $R(\theta^*, \delta_i(\cdot))$ plotted
vs $\theta^*$, where $i$ indexes the estimator. Each estimator is applied to $N$ samples from $\mathcal{N}(\theta^*, \sigma^2 = 1)$. The dark
blue horizontal line is the sample mean (MLE); the red horizontal line is the sample median; the black
curved line is the estimator $\hat{\theta} = \theta_0 = 0$; the green curved line is the posterior mean when $\kappa = 1$; the light blue
curved line is the posterior mean when $\kappa = 5$. (a) $N = 5$ samples. (b) $N = 20$ samples. Adapted from Figure
B.1 of [BS94]. Generated by riskFnGauss.ipynb.
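Let $\hat{\theta} = \hat{\theta}(\mathbf{x}) = \delta(\mathbf{x})$ be the estimated parameter. The risk of this estimator is given by the MSE.
In Section 4.7.6.3, we show that the MSE can be decomposed into squared bias plus variance:

$$\mathrm{MSE}(\hat{\theta}|\theta^*) = \mathbb{V}[\hat{\theta}] + \mathrm{bias}^2(\hat{\theta}) \tag{5.66}$$

where the bias is defined as $\mathrm{bias}(\hat{\theta}) = \mathbb{E}[\hat{\theta} - \theta^*]$. We now use this expression to derive the risk for
each estimator.

$\delta_1$ is the sample mean. This is unbiased, so its risk is

$$\mathrm{MSE}(\delta_1|\theta^*) = \mathbb{V}[\bar{x}] = \frac{\sigma^2}{N} \tag{5.67}$$

$\delta_2$ is the sample median. This is also unbiased. Furthermore, one can show that its variance is
approximately $\pi/(2N)$ (where $\pi = 3.14$), so the risk is

$$\mathrm{MSE}(\delta_2|\theta^*) = \frac{\pi}{2N} \tag{5.68}$$

$\delta_3$ returns the constant $\theta_0$, so its bias is $(\theta_0 - \theta^*)$ and its variance is zero. Hence the risk is

$$\mathrm{MSE}(\delta_3|\theta^*) = (\theta^* - \theta_0)^2 \tag{5.69}$$

Finally, $\delta_\kappa$ is the posterior mean under a Gaussian prior. We can derive its MSE as follows:

$$\begin{aligned}
\mathrm{MSE}(\delta_\kappa|\theta^*) &= \mathbb{E}\left[(w\bar{x} + (1-w)\theta_0 - \theta^*)^2\right] && (5.70) \\
&= \mathbb{E}\left[\left(w(\bar{x} - \theta^*) + (1-w)(\theta_0 - \theta^*)\right)^2\right] && (5.71) \\
&= w^2\frac{\sigma^2}{N} + (1-w)^2(\theta_0 - \theta^*)^2 && (5.72) \\
&= \frac{1}{(N+\kappa)^2}\left(N\sigma^2 + \kappa^2(\theta_0 - \theta^*)^2\right) && (5.73)
\end{aligned}$$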


These functions are plotted in Figure 5.8 for $N \in \{5, 20\}$. We see that in general, the best estimator
depends on the value of $\theta^*$, which is unknown. If $\theta^*$ is very close to $\theta_0$, then $\delta_3$ (which just predicts
$\theta_0$) is best. If $\theta^*$ is within some reasonable range around $\theta_0$, then the posterior mean, which combines
the prior guess of $\theta_0$ with the actual data, is best. If $\theta^*$ is far from $\theta_0$, the MLE is best.
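
The closed-form risks in Equations (5.67) to (5.73) are easy to tabulate. The following sketch (not the
riskFnGauss.ipynb notebook itself; all function names are illustrative) evaluates them at a few values of
$\theta^*$ with $\sigma^2 = 1$ and $\theta_0 = 0$, and checks Equation (5.73) against a small simulation.

```python
import math
import random
import statistics

def mse_sample_mean(n, sigma2=1.0):
    return sigma2 / n                       # Eq. (5.67)

def mse_sample_median(n):
    return math.pi / (2 * n)                # Eq. (5.68), approximation for sigma2 = 1

def mse_fixed(theta_star, theta0=0.0):
    return (theta_star - theta0) ** 2       # Eq. (5.69)

def mse_posterior_mean(theta_star, n, kappa, theta0=0.0, sigma2=1.0):
    return (n * sigma2 + kappa**2 * (theta0 - theta_star) ** 2) / (n + kappa) ** 2  # Eq. (5.73)

def simulate_mse_posterior_mean(theta_star, n, kappa, theta0=0.0, trials=20_000, seed=0):
    """Average squared error of the posterior mean over many simulated datasets."""
    rng = random.Random(seed)
    w = n / (n + kappa)
    total = 0.0
    for _ in range(trials):
        xbar = statistics.fmean(rng.gauss(theta_star, 1.0) for _ in range(n))
        total += (w * xbar + (1 - w) * theta0 - theta_star) ** 2
    return total / trials

n = 5
for theta_star in [0.0, 0.5, 2.0]:
    risks = {
        "mean": mse_sample_mean(n),
        "median": mse_sample_median(n),
        "fixed": mse_fixed(theta_star),
        "postmean k=1": mse_posterior_mean(theta_star, n, 1),
        "postmean k=5": mse_posterior_mean(theta_star, n, 5),
    }
    best = min(risks, key=risks.get)
    print(theta_star, {k: round(v, 3) for k, v in risks.items()}, "best:", best)
```

With $N = 5$, the fixed estimator wins at $\theta^* = 0$, a posterior mean wins at $\theta^* = 0.5$, and the sample mean
wins at $\theta^* = 2$, matching the qualitative description above.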


5.3.1.2 Bayes risk 


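In general, the true state of nature $\theta$ that generates the data $\mathbf{x}$ is unknown, so we cannot compute
the risk given in Equation (5.64). One solution to this is to assume a prior $\pi_0$ for $\theta$, and then average
it out. This gives us the Bayes risk, also called the integrated risk:

$$R_{\pi_0}(\delta) \triangleq \mathbb{E}_{\pi_0(\theta)}\left[R(\theta, \delta)\right] = \int d\theta\, d\mathbf{x}\; \pi_0(\theta)\, p(\mathbf{x}|\theta)\, \ell(\theta, \delta(\mathbf{x})) \tag{5.74}$$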


A decision rule that minimizes the Bayes risk is known as a Bayes estimator. This is equivalent to
the optimal policy recommended by Bayesian decision theory in Equation (5.2) since

$$\delta(\mathbf{x}) = \operatorname*{argmin}_{a} \int d\theta\; \pi_0(\theta)\, p(\mathbf{x}|\theta)\, \ell(\theta, a) = \operatorname*{argmin}_{a} \int d\theta\; p(\theta|\mathbf{x})\, \ell(\theta, a) \tag{5.75}$$

Hence we see that picking the optimal action on a case-by-case basis (as in the Bayesian approach) is
optimal on average (as in the frequentist approach). In other words, the Bayesian approach provides
a good way of achieving frequentist goals. See [BS94, p448] for further discussion of this point.


5.3.1.3 Maximum risk


Of course the use of a prior might seem undesirable in the context of frequentist statistics. We can
therefore define the maximum risk as follows:

$$R_{\max}(\delta) \triangleq \sup_{\theta} R(\theta, \delta) \tag{5.76}$$

A decision rule that minimizes the maximum risk is called a minimax estimator, and is denoted
$\delta_{MM}$. For example, in Figure 5.9, we see that $\delta_1$ has lower worst-case risk than $\delta_2$, ranging over all
possible values of $\theta$, so it is the minimax estimator.

Figure 5.9: Risk functions for two decision procedures, $\delta_1$ and $\delta_2$. Since $\delta_1$ has lower worst case risk, it is the
minimax estimator, even though $\delta_2$ has lower risk for most values of $\theta$. Thus minimax estimators are overly
conservative.


Minimax estimators have a certain appeal. However, computing them can be hard. And furthermore, 
they are very pessimistic. In fact, one can show that all minimax estimators are equivalent to Bayes 
estimators under a least favorable prior. In most statistical situations (excluding game theoretic 
ones), assuming nature is an adversary is not a reasonable assumption. 
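
To make the contrast between Bayes risk and maximum risk concrete, the sketch below revisits the
Gaussian example of Section 5.3.1.1 with $\sigma^2 = 1$ and $\theta_0 = 0$. It assumes a hypothetical prior
$\theta \sim \mathcal{N}(0, 1)$ for the Bayes risk, and approximates the supremum by a maximum over an illustrative grid
$\theta \in [-3, 3]$; both choices are assumptions made for the example.

```python
def risk_sample_mean(theta, n, sigma2=1.0):
    return sigma2 / n  # constant in theta

def risk_posterior_mean(theta, n, kappa, theta0=0.0, sigma2=1.0):
    return (n * sigma2 + kappa**2 * (theta0 - theta) ** 2) / (n + kappa) ** 2

def bayes_risk_posterior_mean(n, kappa, prior_var, sigma2=1.0):
    # Averaging Eq. (5.73) over theta ~ N(theta0, prior_var) replaces
    # (theta0 - theta)^2 by its expectation, prior_var.
    return (n * sigma2 + kappa**2 * prior_var) / (n + kappa) ** 2

n = 5
prior_var = 1.0                              # hypothetical prior N(0, 1)
thetas = [i / 10 for i in range(-30, 31)]    # illustrative grid on [-3, 3]

bayes = {
    "sample mean": risk_sample_mean(0.0, n),  # constant risk, so prior is irrelevant
    "post mean k=1": bayes_risk_posterior_mean(n, 1, prior_var),
    "post mean k=5": bayes_risk_posterior_mean(n, 5, prior_var),
}
worst = {
    "sample mean": max(risk_sample_mean(t, n) for t in thetas),
    "post mean k=1": max(risk_posterior_mean(t, n, 1) for t in thetas),
    "post mean k=5": max(risk_posterior_mean(t, n, 5) for t in thetas),
}
print("lowest Bayes risk:", min(bayes, key=bayes.get))
print("lowest maximum risk on the grid:", min(worst, key=worst.get))
```

Under this prior the posterior mean with $\kappa = 1$ (whose prior variance $\sigma^2/\kappa$ matches the assumed prior)
has the lowest Bayes risk, while the sample mean, whose risk does not depend on $\theta$, has the lowest
worst-case risk among the three.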


5.3.2 Consistent estimators 


Suppose we have a dataset $\mathbf{x} = \{x_n : n = 1 : N\}$ where the samples $x_n \in \mathcal{X}$ are generated iid from a
distribution $p(X|\theta^*)$, where $\theta^* \in \Theta$ is the true parameter. Furthermore, suppose the parameters
are identifiable, meaning that $p(\mathbf{x}|\theta) = p(\mathbf{x}|\theta')$ iff $\theta = \theta'$ for any dataset $\mathbf{x}$. Then we say that an
estimator $\hat{\theta} : \mathcal{X}^N \to \Theta$ is a consistent estimator if $\hat{\theta}(\mathbf{x}) \to \theta^*$ as $N \to \infty$ (where the arrow denotes
convergence in probability). In other words, the procedure $\hat{\theta}$ recovers the true parameter (or a subset
of it) in the limit of infinite data. This is equivalent to minimizing the 0-1 loss, $\ell(\theta^*, \hat{\theta}) = \mathbb{I}(\theta^* \neq \hat{\theta})$.
An example of a consistent estimator is the maximum likelihood estimator (MLE).

Note that an estimator can be unbiased but not consistent. For example, consider the estimator
$\hat{\theta}(\mathbf{x}) = \hat{\theta}(\{x_1, \ldots, x_N\}) = x_N$. This is an unbiased estimator of the true mean $\mu$, since
$\mathbb{E}[\hat{\theta}(\mathbf{x})] = \mathbb{E}[x_N] = \mu$. But the sampling distribution of $\hat{\theta}(\mathbf{x})$ does not converge to a fixed value,
so it cannot converge to the point $\theta^*$.
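
A small simulation illustrates the difference. The sketch below (function name, trial count, and seed
are illustrative) draws many datasets of size $N$ from $\mathcal{N}(\mu, 1)$ and compares the spread of the
"last sample" estimator $x_N$ with that of the sample mean.

```python
import random
import statistics

def estimator_spread(n, trials=2000, mu=0.0, seed=0):
    """Return the std of the last-sample estimator and of the sample mean across datasets."""
    rng = random.Random(seed)
    last_vals, mean_vals = [], []
    for _ in range(trials):
        x = [rng.gauss(mu, 1.0) for _ in range(n)]
        last_vals.append(x[-1])               # unbiased, but ignores N - 1 samples
        mean_vals.append(statistics.fmean(x))  # unbiased and consistent
    return statistics.pstdev(last_vals), statistics.pstdev(mean_vals)

for n in [10, 100]:
    sd_last, sd_mean = estimator_spread(n)
    print(f"N={n}: std of x_N = {sd_last:.3f}, std of sample mean = {sd_mean:.3f}")
```

The spread of $x_N$ stays near 1 regardless of $N$, whereas the spread of the sample mean shrinks like
$1/\sqrt{N}$.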

Although consistency is a desirable property, it is of somewhat limited usefulness in practice since 
most real datasets do not come from our chosen model family (i.e., there is no $\theta^*$ such that $p(\cdot|\theta^*)$
generates the observed data x). In practice, it is more useful to find estimators that minimize some 
discrepancy measure between the empirical distribution and the estimated distribution. If we use KL 
divergence as our discrepancy measure, our estimate becomes the MLE. 


5.3.3 Admissible estimators 


We say that $\delta_1$ dominates $\delta_2$ if $R(\theta, \delta_1) \leq R(\theta, \delta_2)$ for all $\theta$. The domination is said to be strict
if the inequality is strict for some $\theta^*$. An estimator is said to be admissible if it is not strictly
dominated by any other estimator. Interestingly, [Wal47] proved that all admissible decision rules 
are equivalent to some kind of Bayesian decision rule, under some technical conditions. (See [DR21] 
for a more general version of this result.) 

For example, in Figure 5.8, we see that the sample median (dotted red line) always has higher risk
than the sample mean (solid blue line). Therefore the sample median is not an admissible estimator
for the mean. More surprisingly, one can show that the sample mean is not always an admissible
estimator either, even under a Gaussian likelihood model with squared error loss (this is known as
Stein’s paradox [Ste56]).

However, the concept of admissibility is of somewhat limited value. For example, let $X \sim \mathcal{N}(\theta, 1)$,
and consider estimating $\theta$ under squared loss. Consider the estimator $\delta_1(x) = \theta_0$, where $\theta_0$ is a
constant independent of the data. We now show that this is an admissible estimator.

To see this, suppose it were not true. Then there would be some other estimator $\delta_2$ with smaller
risk, so $R(\theta^*, \delta_2) \leq R(\theta^*, \delta_1)$, where the inequality must be strict for some $\theta^*$. Consider the risk at
$\theta^* = \theta_0$. We have $R(\theta_0, \delta_1) = 0$, and

$$R(\theta_0, \delta_2) = \int (\delta_2(x) - \theta_0)^2\, p(x|\theta_0)\, dx \tag{5.77}$$

Since $0 \leq R(\theta^*, \delta_2) \leq R(\theta^*, \delta_1)$ for all $\theta^*$, and $R(\theta_0, \delta_1) = 0$, we have $R(\theta_0, \delta_2) = 0$ and hence
$\delta_2(x) = \theta_0 = \delta_1(x)$. Thus the only way $\delta_2$ can avoid having higher risk than $\delta_1$ at $\theta_0$ is by being
equal to $\delta_1$. Hence there is no other estimator $\delta_2$ with strictly lower risk, so $\delta_1$ is admissible.

Thus we see that the estimator $\delta_1(x) = \theta_0$ is admissible, even though it ignores the data, so
is useless as an estimator. Conversely, it is possible to construct useful estimators that are not
admissible (see e.g., [Jay03, Sec 13.7]).


5.4 Empirical risk minimization 


In this section, we consider how to apply frequentist decision theory in the context of supervised 
learning. 


5.4.1 Empirical risk 


In standard accounts of frequentist decision theory used in statistics textbooks, there is a single
unknown “state of nature”, corresponding to the unknown parameters $\theta^*$ of some model, and we
define the risk as in Equation (5.64), namely $R(\theta^*, \delta) = \mathbb{E}_{p(\mathcal{D}|\theta^*)}\left[\ell(\theta^*, \delta(\mathcal{D}))\right]$.

In supervised learning, we have a different unknown state of nature (namely the output $y$) for each
input $x$, and our estimator $\delta$ is a prediction function $\hat{y} = f(x)$, and the state of nature is the true
distribution $p^*(x, y)$. Thus the risk of an estimator is as follows:

$$R(f, p^*) = R(f) \triangleq \mathbb{E}_{p^*(x)\, p^*(y|x)}\left[\ell(y, f(x))\right] \tag{5.78}$$

This is called the population risk, since the expectations are taken wrt the true joint distribution
$p^*(x, y)$. Of course, $p^*$ is unknown, but we can approximate it using the empirical distribution with
$N$ samples:

$$p_{\mathcal{D}}(x, y|\mathcal{D}) \triangleq \frac{1}{N} \sum_{(x_n, y_n) \in \mathcal{D}} \delta(x - x_n)\, \delta(y - y_n) \tag{5.79}$$

where $p_{\mathcal{D}}(x, y) = p_{\mathrm{tr}}(x, y)$. Plugging this in gives us the empirical risk:

$$R(f, \mathcal{D}) \triangleq \mathbb{E}_{p_{\mathcal{D}}(x, y)}\left[\ell(y, f(x))\right] = \frac{1}{N} \sum_{n=1}^{N} \ell(y_n, f(x_n)) \tag{5.80}$$

Note that $R(f, \mathcal{D})$ is a random variable, since it depends on the training set.
A natural way to choose the predictor is to use

$$\hat{f}_{\mathrm{ERM}} = \operatorname*{argmin}_{f \in \mathcal{H}} R(f, \mathcal{D}) = \operatorname*{argmin}_{f \in \mathcal{H}} \frac{1}{N} \sum_{n=1}^{N} \ell(y_n, f(x_n)) \tag{5.81}$$


where we optimize over a specific hypothesis space H of functions. This is called empirical risk 
minimization (ERM). 
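
As a toy illustration of ERM, the sketch below searches a small finite hypothesis space of 1d threshold
classifiers under 0-1 loss. The dataset, the candidate thresholds, and the function names are all
made up for the example.

```python
def empirical_risk(predict, data):
    """Average 0-1 loss of a predictor on a list of (x, y) pairs."""
    return sum(predict(x) != y for x, y in data) / len(data)

def make_threshold_classifier(t):
    """Hypothesis h_t(x) = 1 if x > t else 0."""
    return lambda x: int(x > t)

# Hypothetical training set; two of the labels are noisy.
data = [(0.1, 0), (0.3, 0), (0.45, 0), (0.55, 1),
        (0.7, 1), (0.9, 1), (0.2, 1), (0.8, 0)]

# Finite hypothesis space H, indexed by the threshold.
thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
risks = {t: empirical_risk(make_threshold_classifier(t), data) for t in thresholds}

best_t = min(risks, key=risks.get)   # the ERM solution within H
print(risks)
print("ERM threshold:", best_t, "empirical risk:", risks[best_t])
```

Because the two noisy points cannot be fit by any threshold, the minimal empirical risk within this
hypothesis space is nonzero.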


5.4.1.1 Approximation error vs estimation error 


In this section, we analyze the theoretical performance of functions that are fit using the ERM
principle. Let $f^{**} = \operatorname{argmin}_f R(f)$ be the function that achieves the minimal possible population risk,
where we optimize over all possible functions. Of course, we cannot consider all possible functions,
so let us also define $f^* = \operatorname{argmin}_{f \in \mathcal{H}} R(f)$ to be the best function in our hypothesis space, $\mathcal{H}$.
Unfortunately we cannot compute $f^*$, since we cannot compute the population risk, so let us finally
define the prediction function that minimizes the empirical risk in our hypothesis space:

$$f_N^* = \operatorname*{argmin}_{f \in \mathcal{H}} R(f, \mathcal{D}) = \operatorname*{argmin}_{f \in \mathcal{H}} \mathbb{E}_{p_{\mathrm{tr}}}\left[\ell(y, f(x))\right] \tag{5.82}$$


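One can show [BB08] that the risk of our chosen predictor compared to the best possible predictor
can be decomposed into two terms, as follows:

$$\mathbb{E}_{p^*}\left[R(f_N^*) - R(f^{**})\right] = \underbrace{R(f^*) - R(f^{**})}_{\mathcal{E}_{\mathrm{app}}(\mathcal{H})} + \underbrace{\mathbb{E}_{p^*}\left[R(f_N^*) - R(f^*)\right]}_{\mathcal{E}_{\mathrm{est}}(\mathcal{H}, N)} \tag{5.83}$$

The first term, $\mathcal{E}_{\mathrm{app}}(\mathcal{H})$, is the approximation error, which measures how closely $\mathcal{H}$ can model the
true optimal function $f^{**}$. The second term, $\mathcal{E}_{\mathrm{est}}(\mathcal{H}, N)$, is the estimation error or generalization
error, which measures the difference in estimated risks due to having a finite training set. We can
approximate this by the difference between the training set error and the test set error, using two
empirical distributions drawn from $p^*$:

$$\mathbb{E}_{p^*}\left[R(f_N^*) - R(f^*)\right] \approx \mathbb{E}_{p_{\mathrm{tr}}}\left[\ell(y, f_N^*(x))\right] - \mathbb{E}_{p_{\mathrm{te}}}\left[\ell(y, f_N^*(x))\right] \tag{5.84}$$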


This difference is often called the generalization gap. 

We can decrease the approximation error by using a more expressive family of functions H, but 
this usually increases the generalization error, due to overfitting. We discuss solutions to this tradeoff 
below. 


5.4.1.2 Regularized risk 


To avoid the chance of overfitting, it is common to add a complexity penalty to the objective function, 
giving us the regularized empirical risk: 


$$R_\lambda(f, \mathcal{D}) = R(f, \mathcal{D}) + \lambda C(f) \tag{5.85}$$

where $C(f)$ measures the complexity of the prediction function $f(x; \theta)$, and $\lambda \geq 0$, which is known
as a hyperparameter, controls the strength of the complexity penalty. (We discuss how to pick $\lambda$
in Section 5.4.2.)

In practice, we usually work with parametric functions, and apply the regularizer to the parameters 
themselves. This yields the following form of the objective: 


$$R_\lambda(\theta, \mathcal{D}) = R(\theta, \mathcal{D}) + \lambda C(\theta) \tag{5.86}$$


Note that, if the loss function is log loss, and the regularizer is a negative log prior, the regularized
risk is given by

$$R_\lambda(\theta, \mathcal{D}) = -\frac{1}{N} \sum_{n=1}^{N} \log p(y_n|x_n, \theta) - \lambda \log p(\theta) \tag{5.87}$$

Minimizing this is equivalent to MAP estimation.


5.4.2 Structural risk 


A natural way to estimate the hyperparameters is to minimize for the lowest achievable empirical
risk:

$$\hat{\lambda} = \operatorname*{argmin}_{\lambda} \min_{\theta} R_\lambda(\theta, \mathcal{D}) \tag{5.88}$$

(This is an example of bilevel optimization, also called nested optimization.) Unfortunately,
this technique will not work, since it will always pick the least amount of regularization, i.e., $\hat{\lambda} = 0$.
To see this, note that

$$\operatorname*{argmin}_{\lambda} \min_{\theta} R_\lambda(\theta, \mathcal{D}) = \operatorname*{argmin}_{\lambda} \min_{\theta} R(\theta, \mathcal{D}) + \lambda C(\theta) \tag{5.89}$$

which is minimized by setting $\lambda = 0$. The problem is that the empirical risk underestimates the
population risk, resulting in overfitting when we choose $\lambda$. This is called optimism of the training
error.

If we knew the regularized population risk $R_\lambda(\theta)$, instead of the regularized empirical risk $R_\lambda(\theta, \mathcal{D})$,
we could use it to pick a model of the right complexity (e.g., value of $\lambda$). This is known as structural
risk minimization [Vap98]. There are two main ways to estimate the population risk for a
given model (value of $\lambda$), namely cross-validation (Section 5.4.3), and statistical learning theory
(Section 5.4.4), which we discuss below.


5.4.3 Cross-validation 


In this section, we discuss a simple way to estimate the population risk for a supervised learning
setup. We simply partition the dataset into two, the part used for training the model, and a second
part, called the validation set or holdout set, used for assessing the risk. We can fit the model on
the training set, and use its performance on the validation set as an approximation to the population
risk.

To explain the method in more detail, we need some notation. First we make the dependence of
the empirical risk on the dataset more explicit as follows:

$$R_\lambda(\theta, \mathcal{D}) = \frac{1}{|\mathcal{D}|} \sum_{(x, y) \in \mathcal{D}} \ell(y, f(x; \theta)) + \lambda C(\theta) \tag{5.90}$$

Let us also define $\hat{\theta}_\lambda(\mathcal{D}) = \operatorname{argmin}_\theta R_\lambda(\theta, \mathcal{D})$. Finally, let $\mathcal{D}_{\mathrm{train}}$ and $\mathcal{D}_{\mathrm{valid}}$ be a partition of $\mathcal{D}$.
(Often we use about 80% of the data for the training set, and 20% for the validation set.)

For each model $\lambda$, we fit it to the training set to get $\hat{\theta}_\lambda(\mathcal{D}_{\mathrm{train}})$. We then use the unregularized
empirical risk on the validation set as an estimate of the population risk. This is known as the
validation risk:

$$R_\lambda^{\mathrm{val}} \triangleq R_0(\hat{\theta}_\lambda(\mathcal{D}_{\mathrm{train}}), \mathcal{D}_{\mathrm{valid}}) \tag{5.91}$$


Note that we use different data to train and evaluate the model. 

The above technique can work very well. However, if the number of training cases is small, this 
technique runs into problems, because the model won’t have enough data to train on, and we won’t 
have enough data to make a reliable estimate of the future performance. 

A simple but popular solution to this is to use cross validation (CV). The idea is as follows: we
split the training data into $K$ folds; then, for each fold $k \in \{1, \ldots, K\}$, we train on all the folds but
the $k$’th, and test on the $k$’th, in a round-robin fashion, as sketched in Figure 4.6. Formally, we have

$$R_\lambda^{\mathrm{cv}} \triangleq \frac{1}{K} \sum_{k=1}^{K} R_0(\hat{\theta}_\lambda(\mathcal{D}_{-k}), \mathcal{D}_k) \tag{5.92}$$

where $\mathcal{D}_k$ is the data in the $k$’th fold, and $\mathcal{D}_{-k}$ is all the other data. This is called the cross-validated
risk. Figure 4.6 illustrates this procedure for $K = 5$. If we set $K = N$, we get a method known as
leave-one-out cross-validation, since we always train on $N - 1$ items and test on the remaining
one.

We can use the CV estimate as an objective inside of an optimization routine to pick the optimal
hyperparameter, $\hat{\lambda} = \operatorname{argmin}_\lambda R_\lambda^{\mathrm{cv}}$. Finally we combine all the available data (training and validation),
and re-estimate the model parameters using $\hat{\theta} = \operatorname{argmin}_\theta R_{\hat{\lambda}}(\theta, \mathcal{D})$.
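
The whole procedure can be sketched in a few lines. As an assumption for illustration, the model is
again 1d ridge regression without intercept (closed-form fit), the data is synthetic, and folds are formed
by striding through the indices; none of these choices come from the text.

```python
import random

def fit_ridge_1d(xs, ys, lam):
    """Minimizer of (1/N) sum (y - theta x)^2 + lam * theta^2."""
    n = len(xs)
    return sum(x * y for x, y in zip(xs, ys)) / (sum(x * x for x in xs) + n * lam)

def unregularized_risk(theta, xs, ys):
    return sum((y - theta * x) ** 2 for x, y in zip(xs, ys)) / len(xs)

def cross_validated_risk(xs, ys, lam, num_folds):
    """Eq. (5.92): average validation risk over K folds."""
    n = len(xs)
    fold_risks = []
    for k in range(num_folds):
        valid = [i for i in range(n) if i % num_folds == k]
        train = [i for i in range(n) if i % num_folds != k]
        theta = fit_ridge_1d([xs[i] for i in train], [ys[i] for i in train], lam)
        fold_risks.append(
            unregularized_risk(theta, [xs[i] for i in valid], [ys[i] for i in valid]))
    return sum(fold_risks) / num_folds

# Synthetic data: y = 2x + noise (illustrative).
rng = random.Random(0)
xs = [rng.uniform(-1, 1) for _ in range(40)]
ys = [2.0 * x + rng.gauss(0, 0.5) for x in xs]

lambdas = [0.0, 0.01, 0.1, 1.0]
cv_risks = {lam: cross_validated_risk(xs, ys, lam, num_folds=5) for lam in lambdas}
best_lambda = min(cv_risks, key=cv_risks.get)

# Refit on all the data with the selected hyperparameter.
theta_hat = fit_ridge_1d(xs, ys, best_lambda)
print(cv_risks, "selected lambda:", best_lambda, "theta_hat:", round(theta_hat, 3))
```

Unlike the training objective, the cross-validated risk is computed on held-out folds, so heavy
regularization (which shrinks $\hat{\theta}$ far below the true slope here) is penalized by a visibly larger risk.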


5.4.4 Statistical learning theory * 


The principal problem with cross validation is that it is slow, since we have to fit the model multiple 
times. This motivates the desire to compute analytic approximations or bounds on the population 
risk. This is studied in the field of statistical learning theory (SLT) (see e.g., [Vap98]). 

More precisely, the goal of SLT is to upper bound the generalization error with a certain probability.
If the bound is satisfied, then we can be confident that a hypothesis that is chosen by minimizing
the empirical risk will have low population risk. In the case of binary classifiers, this means the
hypothesis will make the correct predictions; in this case we say it is probably approximately
correct, and that the hypothesis class is PAC learnable (see e.g., [KV94] for details).


5.4.4.1 Bounding the generalization error


In this section, we establish conditions under which we can prove that a hypothesis class is PAC
learnable. Let us initially consider the case where the hypothesis space is finite, with size $\dim(\mathcal{H}) = |\mathcal{H}|$.
In other words, we are selecting a hypothesis from a finite list, rather than optimizing real-valued
parameters. In this case, we can prove the following.


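Theorem 5.4.1. For any data distribution $p^*$, and any dataset $\mathcal{D}$ of size $N$ drawn from $p^*$, the
probability that the generalization error of a binary classifier will be more than $\epsilon$, in the worst case,
is upper bounded as follows:

$$P\left(\max_{h \in \mathcal{H}} |R(h) - R(h, \mathcal{D})| > \epsilon\right) \leq 2 \dim(\mathcal{H})\, e^{-2N\epsilon^2} \tag{5.93}$$

where $R(h, \mathcal{D}) = \frac{1}{N} \sum_{i=1}^{N} \mathbb{I}(h(x_i) \neq y_i^*)$ is the empirical risk, and $R(h) = \mathbb{E}\left[\mathbb{I}(h(x) \neq y^*)\right]$ is the
population risk.

Proof. Before we prove this, we introduce two useful results. First, Hoeffding’s inequality, which
states that if $E_1, \ldots, E_N \sim \mathrm{Ber}(\theta)$, then, for any $\epsilon > 0$,

$$P(|\bar{E} - \theta| > \epsilon) \leq 2 e^{-2N\epsilon^2} \tag{5.94}$$

where $\bar{E} = \frac{1}{N} \sum_{i=1}^{N} E_i$ is the empirical error rate, and $\theta$ is the true error rate. Second, the union
bound, which says that if $A_1, \ldots, A_d$ are a set of events, then $P(\cup_{i=1}^{d} A_i) \leq \sum_{i=1}^{d} P(A_i)$. Using
these results, we have

$$\begin{aligned}
P\left(\max_{h \in \mathcal{H}} |R(h) - R(h, \mathcal{D})| > \epsilon\right) &= P\left(\bigcup_{h \in \mathcal{H}} |R(h) - R(h, \mathcal{D})| > \epsilon\right) && (5.95) \\
&\leq \sum_{h \in \mathcal{H}} P\left(|R(h) - R(h, \mathcal{D})| > \epsilon\right) && (5.96) \\
&\leq \sum_{h \in \mathcal{H}} 2 e^{-2N\epsilon^2} = 2 \dim(\mathcal{H})\, e^{-2N\epsilon^2} && (5.97)
\end{aligned}$$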


This bound tells us that the optimism of the training error increases with $\dim(\mathcal{H})$ but decreases
with $N = |\mathcal{D}|$, as is to be expected.
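
To get a feel for the numbers, the bound in Equation (5.93) can be evaluated directly, and inverted to
give a sufficient sample size: requiring $2\dim(\mathcal{H}) e^{-2N\epsilon^2} \leq \delta$ gives $N \geq \log(2\dim(\mathcal{H})/\delta) / (2\epsilon^2)$.
The sketch below does both; the function names and the example values of $\epsilon$, $\delta$, and $|\mathcal{H}|$ are
illustrative.

```python
import math

def hoeffding_union_bound(num_hypotheses, n, eps):
    """Right-hand side of Eq. (5.93) for a finite hypothesis space."""
    return 2 * num_hypotheses * math.exp(-2 * n * eps**2)

def samples_needed(num_hypotheses, eps, delta):
    """Smallest N for which the bound is at most delta."""
    return math.ceil(math.log(2 * num_hypotheses / delta) / (2 * eps**2))

num_hypotheses, eps, delta = 100, 0.1, 0.05
n_required = samples_needed(num_hypotheses, eps, delta)
print("N required:", n_required)
print("bound at that N:", hoeffding_union_bound(num_hypotheses, n_required, eps))
```

Note that the required $N$ grows only logarithmically in the number of hypotheses, but quadratically
in $1/\epsilon$.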


5.4.4.2 VC dimension 


If the hypothesis space H is infinite (e.g., we have real-valued parameters), we cannot use dim(H) = 
|H|. Instead, we can use a quantity called the VC dimension of the hypothesis class, named after 
Vapnik and Chervonenkis; this measures the degrees of freedom (effective number of parameters) of 
the hypothesis class. See e.g., [Vap98] for the details. 

Unfortunately, it is hard to compute the VC dimension for many interesting models, and the upper
bounds are usually very loose, making this approach of limited practical value. However, various
other, more practical, estimates of generalization error have recently been devised, especially for
DNNs, such as [Jia+20].

Figure 5.10: (a) Illustration of the Neyman-Pearson hypothesis testing paradigm. Generated by
neymanPearson2.ipynb. (b) Two hypothetical two-sided power curves. B dominates A. Adapted from Figure 6.3.5 of
[LM86]. Generated by twoPowerCurves.ipynb.


5.5 Frequentist hypothesis testing * 


In this section, we discuss ways to determine if a hypothesis (model) is plausible or not, in the
light of data $\mathcal{D}$.


5.5.1 Likelihood ratio test 


When deciding if a model is a good description of some data or not, it is always useful to ask “relative
to what”. To make this concrete, suppose we have two hypotheses, known as the null hypothesis $H_0$
and an alternative hypothesis $H_1$, and we want to choose the one we think is more likely. We can
think of this as a binary classification problem, where $H \in \{0, 1\}$ represents the identity of the “true”
model. A natural approach is to use Bayesian model selection, as we discussed in Section 5.2.1, to
compute $p(H|\mathcal{D})$, and then to pick the most probable model. Here we discuss a frequentist approach.

Suppose we have a uniform prior, so $p(H = 0) = p(H = 1) = 0.5$, and that we use 0-1 loss. Then
the optimal decision rule is to accept $H_0$ iff $\frac{p(\mathcal{D}|H_0)}{p(\mathcal{D}|H_1)} > 1$. This is called the likelihood ratio test.
We give some examples of this below.


5.5.1.1 Example: comparing Gaussian means 


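Suppose we are interested in testing whether some data comes from a Gaussian with mean $\mu_0$ or
from a Gaussian with mean $\mu_1$. (We assume a known shared variance $\sigma^2$.) This is illustrated in
Figure 5.10a, where we plot $p(x|H_0)$ and $p(x|H_1)$. We can derive the likelihood ratio as follows:

$$\begin{aligned}
\frac{p(\mathcal{D}|H_0)}{p(\mathcal{D}|H_1)} &= \frac{\exp\left(-\frac{1}{2\sigma^2} \sum_{n=1}^{N} (x_n - \mu_0)^2\right)}{\exp\left(-\frac{1}{2\sigma^2} \sum_{n=1}^{N} (x_n - \mu_1)^2\right)} && (5.98) \\
&= \exp\left(\frac{1}{2\sigma^2}\left(2N\bar{x}(\mu_0 - \mu_1) + N\mu_1^2 - N\mu_0^2\right)\right) && (5.99)
\end{aligned}$$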


We see that this ratio only depends on the observed data via its mean, $\bar{x}$. From Figure 5.10a,
we can see that $\frac{p(\mathcal{D}|H_0)}{p(\mathcal{D}|H_1)} > 1$ iff $\bar{x} < x^*$, where $x^*$ is the point where the two pdf’s intersect (we are
assuming this point is unique).


5.5.1.2 Simple vs compound hypotheses


In Section 5.5.1.1, the parameters for the null and alternative hypotheses were either fully specified
($\mu_0$ and $\mu_1$) or shared ($\sigma^2$). This is called a simple hypothesis test. In general, a hypothesis might
not fully specify all the parameters; this is called a compound hypothesis. In this case, we could
integrate out these unknown parameters, as in the Bayesian approach, since a hypothesis with more
parameters will always have higher likelihood. However, this can be computationally difficult, and is
prone to problems caused by prior misspecification. As an alternative approach, we can “maximize
out” the parameters, which gives us the maximum likelihood ratio test:

$$\frac{p(H_0|\mathcal{D})}{p(H_1|\mathcal{D})} = \frac{\int_{\theta \in H_0} p(\theta)\, p_\theta(\mathcal{D})}{\int_{\theta \in H_1} p(\theta)\, p_\theta(\mathcal{D})} \approx \frac{\max_{\theta \in H_0} p_\theta(\mathcal{D})}{\max_{\theta \in H_1} p_\theta(\mathcal{D})} \tag{5.100}$$
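
Returning to the simple Gaussian test of Section 5.5.1.1, the sketch below checks numerically that the
log of the likelihood ratio computed directly from Equation (5.98) agrees with the closed form in
Equation (5.99), which depends on the data only through $\bar{x}$. The data values are made up; the code also
uses the fact that, for two Gaussians with equal variance, the pdfs intersect at the midpoint
$x^* = (\mu_0 + \mu_1)/2$.

```python
def log_lr_direct(xs, mu0, mu1, sigma):
    """log p(D|H0) - log p(D|H1), from the sums of squares in Eq. (5.98)."""
    s0 = sum((x - mu0) ** 2 for x in xs)
    s1 = sum((x - mu1) ** 2 for x in xs)
    return (s1 - s0) / (2 * sigma**2)

def log_lr_from_mean(xbar, n, mu0, mu1, sigma):
    """Exponent of Eq. (5.99); depends on the data only via the sample mean."""
    return (2 * n * xbar * (mu0 - mu1) + n * mu1**2 - n * mu0**2) / (2 * sigma**2)

xs = [0.2, -0.5, 0.9, 0.1, 0.4]       # hypothetical observations
mu0, mu1, sigma = 0.0, 1.0, 1.0
xbar = sum(xs) / len(xs)

log_lr_a = log_lr_direct(xs, mu0, mu1, sigma)
log_lr_b = log_lr_from_mean(xbar, len(xs), mu0, mu1, sigma)
accept_h0 = log_lr_b > 0              # likelihood ratio > 1
x_star = (mu0 + mu1) / 2              # intersection of the two equal-variance pdfs
print(log_lr_a, log_lr_b, accept_h0, xbar < x_star)
```

Here $\mu_0 < \mu_1$, so accepting $H_0$ is equivalent to $\bar{x} < x^*$, as stated in the text.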


5.5.2 Type I vs type II errors and the Neyman-Pearson lemma 


Hypothesis testing is a kind of binary classification problem. As we discussed in Section 5.1.3, there
are two kinds of error we can make, known as a false positive or type I error, which corresponds
to accidentally rejecting the null hypothesis $H_0$ when it is true, and a false negative or type II
error, which corresponds to accidentally accepting the null when the alternative is true. The type
I error rate $\alpha$ is called the significance of the test. In our Gaussian mean example, we see from
Figure 5.10a that the type I error rate is the vertical shaded blue area:

$$\begin{aligned}
\alpha(\mu_0) &= p(\text{type I error}) = p(\text{reject } H_0|H_0 \text{ is true}) && (5.101) \\
&= p(\bar{X}(\tilde{\mathcal{D}}) > x^*|\tilde{\mathcal{D}} \sim H_0) && (5.102) \\
&= p\left(\frac{\bar{X} - \mu_0}{\sigma/\sqrt{N}} > \frac{x^* - \mu_0}{\sigma/\sqrt{N}}\right) && (5.103)
\end{aligned}$$

Hence $x^* = z_\alpha \sigma/\sqrt{N} + \mu_0$, where $z_\alpha$ is the upper $\alpha$ quantile of the standard Normal. The type II
error rate is denoted by $\beta$, and is given by

$$\beta(\mu_1) = p(\text{type II error}) = p(\text{accept } H_0|H_1 \text{ is true}) = p(\bar{X}(\tilde{\mathcal{D}}) < x^*|\tilde{\mathcal{D}} \sim H_1) \tag{5.104}$$

This is shown by the horizontal shaded red area in Figure 5.10a.

We define the power of a test as $1 - \beta(\mu_1)$; this is the probability that we reject $H_0$ given that $H_1$ is
true. In other words, it is the ability to correctly recognize that the null hypothesis is wrong. Clearly
the least power occurs if $\mu_1 = \mu_0$ (so the curves overlap); in this case, we have $1 - \beta(\mu_1) = \alpha(\mu_0)$. As
$\mu_1$ and $\mu_0$ become further apart, the power approaches 1 (because the shaded red area gets smaller,
$\beta \to 0$). If we have two tests, A and B, where power(B) ≥ power(A) for the same type I error rate,
we say B dominates A. See Figure 5.10b. A test with highest power under $H_1$ amongst all tests
with significance level $\alpha$ is called a most powerful test. It turns out that the likelihood ratio test
is a most powerful test, a result known as the Neyman-Pearson lemma.
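
For the Gaussian mean example, the threshold, the type II error rate, and the power can all be
computed from the standard Normal cdf. The sketch below uses Python's `statistics.NormalDist`; the
function name and the particular values of $\mu_0$, $\mu_1$, $\sigma$, $N$, and $\alpha$ are illustrative, and it assumes $\mu_1 \geq \mu_0$
so that we reject for large $\bar{x}$.

```python
from statistics import NormalDist

def neyman_pearson_summary(mu0, mu1, sigma, n, alpha):
    """Return (x_star, beta, power) for the one-sided Gaussian mean test."""
    std_normal = NormalDist()
    se = sigma / n**0.5                        # std of the sample mean
    z_alpha = std_normal.inv_cdf(1 - alpha)    # upper alpha quantile
    x_star = z_alpha * se + mu0                # rejection threshold
    beta = std_normal.cdf((x_star - mu1) / se) # p(xbar < x_star | H1), Eq. (5.104)
    return x_star, beta, 1 - beta

x_star, beta, power = neyman_pearson_summary(mu0=0.0, mu1=1.0, sigma=1.0, n=4, alpha=0.05)
print(f"x* = {x_star:.3f}, beta = {beta:.3f}, power = {power:.3f}")

# When the two means coincide, the power reduces to alpha.
_, _, power_null = neyman_pearson_summary(mu0=0.0, mu1=0.0, sigma=1.0, n=4, alpha=0.05)
print(f"power when mu1 = mu0: {power_null:.3f}")
```

Increasing $N$ or the separation $\mu_1 - \mu_0$ shrinks $\beta$ and pushes the power toward 1.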


5.5.3 Null hypothesis significance testing (NHST) and p-values 


In the above decision-theoretic (or Neyman-Pearson) approach to hypothesis testing, we had to
specify a null hypothesis $H_0$ as well as an alternative hypothesis $H_1$ so that we can compute $p(\mathcal{D}|H_0)$
and $p(\mathcal{D}|H_1)$. In some cases, it is difficult to define an alternative hypothesis, and we just want
to test if a simple null hypothesis is “plausible” given some data. To do this, we can define a test
statistic $\mathrm{test}(\mathcal{D})$, and then we can compare its observed value to the value we would expect if the
data came from the null hypothesis, $\mathrm{test}(\tilde{\mathcal{D}})$ where $\tilde{\mathcal{D}} \sim H_0$. If the observed value is unexpected
given $H_0$, we reject the null hypothesis. To quantify this, we compute the probability of seeing a
test value that is as large or larger than the observed value (assuming that larger values make $H_0$
less plausible). More precisely, we define the p-value to be the probability, under the null hypothesis,
of observing a test statistic that is as large or larger than that actually observed:

$$\mathrm{pval} \triangleq \Pr(\mathrm{test}(\tilde{\mathcal{D}}) \geq \mathrm{test}(\mathcal{D})|\tilde{\mathcal{D}} \sim H_0) \tag{5.105}$$

In other words, $\mathrm{pval} = \Pr(\mathrm{test}_{\mathrm{null}} \geq \mathrm{test}_{\mathrm{obs}})$, where $\mathrm{test}_{\mathrm{obs}} = \mathrm{test}(\mathcal{D})$ and $\mathrm{test}_{\mathrm{null}} = \mathrm{test}(\tilde{\mathcal{D}})$, where
$\tilde{\mathcal{D}} \sim H_0$ is hypothetical future data. Smaller values correspond to stronger evidence against $H_0$.

Traditionally we reject the null hypothesis if the p-value is less than $\alpha = 0.05$; this is called the
significance level of the test, and the whole approach is called null hypothesis significance
testing or NHST. By construction, such a test will have a type I error rate (accidentally rejecting
the null when it is true) of value $\alpha$. Note that this decision rule corresponds to picking decision
threshold $t^*$ such that $\Pr(\mathrm{test}(\tilde{\mathcal{D}}) \geq t^*|H_0) = \alpha$. If we set $t^* = \mathrm{test}(\mathcal{D})$, then $\alpha$ will be equal to the
observed p-value. Thus the p-value is the smallest value of $\alpha$ for which we can reject $H_0$.

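We can compute the p-value using $\mathrm{pval} = 1 - \Phi(\mathrm{test}(\mathcal{D}))$, where $\Phi$ is the cdf of the sampling
distribution of the test statistic. This is called a one-sided p-value. In some cases it can be
more appropriate to use a two-sided p-value of the form $\mathrm{pval} = \Pr(\mathrm{test}(\tilde{\mathcal{D}}) \geq |\mathrm{test}(\mathcal{D})|\,|\,\tilde{\mathcal{D}} \sim H_0)
+ \Pr(\mathrm{test}(\tilde{\mathcal{D}}) \leq -|\mathrm{test}(\mathcal{D})|\,|\,\tilde{\mathcal{D}} \sim H_0)$. For example, suppose we use $\mathrm{test}(\mathcal{D}) = (\hat{\theta}(\mathcal{D}) - \theta_0)/\mathrm{se}(\mathcal{D})$,
where $\theta_0$ is the value for $\theta^*$ given $H_0$, and $\hat{\theta}$ is the MLE; this is known as the Wald statistic.
Based on the asymptotic normality of the MLE discussed in Section 4.7.2, we have that $\mathrm{pval} =
\Pr(|\mathrm{test}(\tilde{\mathcal{D}})| \geq |\mathrm{test}(\mathcal{D})|\,|\,H_0) \approx \Pr(|Z| \geq |\mathrm{test}(\mathcal{D})|) = 2\Phi(-|\mathrm{test}(\mathcal{D})|)$, where $Z \sim \mathcal{N}(0, 1)$.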
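
As a small worked instance of the Wald statistic, consider testing whether a coin is fair. This is an
illustrative example, not one from the text: we assume 60 heads in 100 tosses, take $\theta_0 = 0.5$, and use the
plug-in standard error $\sqrt{\hat{\theta}(1 - \hat{\theta})/N}$ for the Bernoulli MLE.

```python
from statistics import NormalDist

def wald_two_sided_pvalue(theta_hat, theta0, se):
    """Two-sided p-value 2 * Phi(-|z|) for the Wald statistic z = (theta_hat - theta0) / se."""
    z = (theta_hat - theta0) / se
    return 2 * NormalDist().cdf(-abs(z))

n, heads = 100, 60                            # hypothetical data
theta_hat = heads / n                         # MLE of the heads probability
se = (theta_hat * (1 - theta_hat) / n) ** 0.5  # plug-in standard error (assumption)
pval = wald_two_sided_pvalue(theta_hat, theta0=0.5, se=se)
print(f"z = {(theta_hat - 0.5) / se:.3f}, two-sided p-value = {pval:.4f}")
```

The p-value falls just below 0.05, so NHST at the conventional level would reject the hypothesis of a
fair coin.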

We see that, to compute the p-value, we need to compute the sampling distribution of the test
statistic under the null hypothesis. Often we use a large sample (Gaussian) approximation, as we
illustrated above. However, in the case where we want to test the null hypothesis that two distributions
are the same, we can use the non-parametric permutation test, which makes no assumptions about
the distribution. For example, suppose we have $m$ samples $X_i$ from $P_X$ and $n$ samples $Y_i$ from $P_Y$
and the null hypothesis is $P_X = P_Y$. Define the test statistic $\mathrm{test}(X_1, \ldots, X_m, Y_1, \ldots, Y_n) = |\bar{X} - \bar{Y}|$.
If we permute the order of the samples, then, under $H_0$, the distribution of this statistic should not
change. So we can sample random permutations to approximate $p(\mathrm{test}(\tilde{\mathcal{D}})|\tilde{\mathcal{D}} \sim H_0)$, from which we can
compute the tail probability of $\mathrm{test}(\mathcal{D})$ computed using the unshuffled data. For more details, see e.g.
[Was04, p162].
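
A minimal sketch of this procedure is shown below. The function name, the number of permutations, the
seed, and the two small datasets are illustrative choices.

```python
import random

def permutation_test(xs, ys, num_permutations=5000, seed=0):
    """Approximate p-value for H0: P_X = P_Y using the statistic |mean(X) - mean(Y)|."""
    rng = random.Random(seed)

    def stat(a, b):
        return abs(sum(a) / len(a) - sum(b) / len(b))

    observed = stat(xs, ys)
    pooled = list(xs) + list(ys)
    m = len(xs)
    count = 0
    for _ in range(num_permutations):
        rng.shuffle(pooled)                    # random relabeling of the pooled samples
        if stat(pooled[:m], pooled[m:]) >= observed:
            count += 1
    return count / num_permutations

# Clearly separated groups: small p-value expected.
p_diff = permutation_test([5.1, 4.9, 5.3, 5.2, 5.0, 5.4], [3.9, 4.1, 4.0, 3.8, 4.2, 4.05])
# Heavily overlapping groups: large p-value expected.
p_same = permutation_test([1, 2, 3, 4, 5, 6], [1.5, 2.5, 3.5, 4.5, 5.5, 0.5])
print(p_diff, p_same)
```

With only a few samples one could enumerate all relabelings exactly; random sampling of permutations
is the usual shortcut when that is too expensive.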

Note that a p-value of 0.05 does not mean that the alternative hypothesis $H_1$ is true with probability
0.95. Indeed, even most scientists misinterpret p-values.⁴ The quantity that most people want to
compute is the Bayesian posterior $p(H|\mathcal{D})$. For more on this important distinction, see Section 5.5.4.

4. See e.g., https://fivethirtyeight.com/features/not-even-scientists-can-easily-explain-p-values/.


5.5.4 p-values considered harmful


                     | Ineffective   Effective |
“Not significant”    |     171           4     | 175
“Significant”        |       9          16     |  25
                     |     180          20     | 200

Table 5.8: Some statistics of a hypothetical clinical trial. Source: [SAM04, p74].

A p-value is often interpreted as the likelihood of the data under the null hypothesis, so small values
are interpreted to mean that $H_0$ is unlikely, and therefore that $H_1$ is likely. The reasoning is roughly
as follows:


If $H_0$ is true, then this test statistic would probably not occur. This statistic did occur.
Therefore $H_0$ is probably false.


However, this is invalid reasoning. To see why, consider the following example (from [Coh94]): 


If a person is an American, then he is probably not a member of Congress. This person is a 
member of Congress. Therefore he is probably not an American. 


This is obviously fallacious reasoning. By contrast, the following logical argument is valid reasoning: 


If a person is a Martian, then he is not a member of Congress. This person is a member of 
Congress. Therefore he is not a Martian. 


The difference between these two cases is that the Martian example is using deduction, that is,
reasoning forward from logical definitions to their consequences. More precisely, this example uses a
rule from logic called modus tollens, in which we start out with a definition of the form $P \Rightarrow Q$;
when we observe $\neg Q$, we can conclude $\neg P$. By contrast, the American example concerns induction,
that is, reasoning backwards from observed evidence to probable (but not necessarily true) causes
using statistical regularities, not logical definitions.

To perform induction, we need to use probabilistic inference (as explained in detail in [Jay03]). In 
particular, to compute the probability of the null hypothesis, we should use Bayes rule, as follows: 


p(D|Ho)p(Ho) 


p(Ho|D) = p(D|Ho)p( Ho) +p(D|Hı)p(H1) 


(5.106) 


If the prior is uniform, so p(Ho) = p(H1) = 0.5, this can be rewritten in terms of the likelihood 
ratio LR = p(D|Ho)/p(P|M1) as follows: 


LR 
LR+1 


p(Ho|D) = (5.107) 
In the American Congress example, $D$ is the observation that the person is a member of Congress.
The null hypothesis $H_0$ is that the person is American, and the alternative hypothesis $H_1$ is that the
person is not American. We assume that $p(D|H_0)$ is low, since most Americans are not members of
Congress. However, $p(D|H_1)$ is also low; in fact, in this example, it is 0, since only Americans can
be members of Congress. Hence $LR = \infty$, so $p(H_0|D) = 1.0$, as intuition suggests. Note, however,
that NHST ignores $p(D|H_1)$ as well as the prior $p(H_0)$, so it gives the wrong results, not just in
this problem, but in many problems.
In general there can be huge differences between p-values and $p(H_0|D)$. In particular, [SBB01]
show that even if the p-value is as low as 0.05, the posterior probability of $H_0$ can be as high as 30%
or more, even with a uniform prior.

Consider this concrete example from [SAM04, p74]. Suppose 200 clinical trials are carried out for
some drug, and we get the data in Table 5.8. Suppose we perform a statistical test of whether the
drug has a significant effect or not. The test has a type I error rate of $\alpha = 9/180 = 0.05$ and a type
II error rate of $\beta = 4/20 = 0.2$.

We can compute the probability that the drug is not effective, given that the result is supposedly 
“significant”, as follows: 


$$\begin{aligned}
p(H_0|\text{'significant'}) &= \frac{p(\text{'significant'}|H_0)\,p(H_0)}{p(\text{'significant'}|H_0)\,p(H_0) + p(\text{'significant'}|H_1)\,p(H_1)} && \text{(5.108)} \\
&= \frac{p(\text{type I error})\,p(H_0)}{p(\text{type I error})\,p(H_0) + (1 - p(\text{type II error}))\,p(H_1)} && \text{(5.109)} \\
&= \frac{\alpha\, p(H_0)}{\alpha\, p(H_0) + (1 - \beta)\, p(H_1)} && \text{(5.110)}
\end{aligned}$$


If we have prior knowledge, based on past experience, that most (say 90%) drugs are ineffective,
then we find $p(H_0|\text{'significant'}) = 0.36$, which is much more than the 5% probability people usually
associate with a p-value of $\alpha = 0.05$.

Thus we should distrust claims of statistical significance if they violate our prior knowledge. 
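
The calculation in Equation (5.110) can be checked numerically. The following is a minimal sketch; the function name `posterior_null_given_significant` is made up for illustration, and the error rates are read directly off Table 5.8.

```python
def posterior_null_given_significant(alpha, beta, prior_h0):
    """p(H0 | 'significant') via Bayes' rule, as in Equation (5.110)."""
    prior_h1 = 1.0 - prior_h0
    numerator = alpha * prior_h0                      # p('significant' | H0) p(H0)
    evidence = numerator + (1.0 - beta) * prior_h1    # p('significant')
    return numerator / evidence


# Error rates read off Table 5.8.
alpha = 9 / 180   # type I error: 'significant' results among ineffective drugs
beta = 4 / 20     # type II error: 'not significant' results among effective drugs

post_skeptical = posterior_null_given_significant(alpha, beta, prior_h0=0.9)
post_uniform = posterior_null_given_significant(alpha, beta, prior_h0=0.5)

print(f"prior p(H0)=0.9 -> p(H0|'significant') = {post_skeptical:.3f}")
print(f"prior p(H0)=0.5 -> p(H0|'significant') = {post_uniform:.3f}")
```

With the 90% prior, the result agrees with reading the answer straight off the table: 9 of the 25 “significant” results come from ineffective drugs. With a uniform prior the posterior is much smaller, which shows how strongly the conclusion depends on the prior.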


5.5.5 Why isn’t everyone a Bayesian? 


In Section 4.7.5 and Section 5.5.4, we have seen that inference based on frequentist principles can
exhibit various forms of counter-intuitive behavior that can sometimes contradict common sense
reasoning, as has been pointed out in multiple articles (see e.g., [Mat98; MS11; Kru13; Gel16; Hoe+14;
Lyu+20; Cha+19b; Cla21]).

The fundamental reason for these problems is that frequentist inference violates the likelihood 
principle [BW88], which says that inference should be based on the likelihood of the observed data, 
not on hypothetical future data that you have not observed. Bayes obviously satisfies the likelihood 
principle, and consequently does not suffer from these pathologies. 

Given these fundamental flaws of frequentist statistics, and the fact that Bayesian methods do not 
have such flaws, an obvious question to ask is: “Why isn’t everyone a Bayesian?” The (frequentist) 
statistician Bradley Efron wrote a paper with exactly this title [Efr86]. His short paper is well worth 
reading for anyone interested in this topic. Below we quote his opening section: 


The title is a reasonable question to ask on at least two counts. First of all, everyone used to 
be a Bayesian. Laplace wholeheartedly endorsed Bayes’s formulation of the inference problem, 
and most 19th-century scientists followed suit. This included Gauss, whose statistical work is 
usually presented in frequentist terms. 


A second and more important point is the cogency of the Bayesian argument. Modern 
statisticians, following the lead of Savage and de Finetti, have advanced powerful theoretical 
arguments for preferring Bayesian inference. A byproduct of this work is a disturbing catalogue 
of inconsistencies in the frequentist point of view. 
Figure 5.11: Cartoon illustrating the difference between frequentists and Bayesians. (The p < 0.05 comment
is explained in Section 5.5.4. The betting comment is a reference to the Dutch book theorem, which essentially
proves that the Bayesian approach to gambling (and other decision theory problems) is optimal, as explained
in e.g., [Háj08].) From https://xkcd.com/1132/. Used with kind permission of Randall Munroe (author
of xkcd).


Nevertheless, everyone is not a Bayesian. The current era (1986) is the first century in which 
statistics has been widely used for scientific reporting, and in fact, 20th-century statistics is 
mainly non-Bayesian. However, Lindley (1975) predicts a change for the 21st century. 


Time will tell whether Lindley was right. However, the trends seem to be going in this direction. 
For example, some journals have banned p-values [TM15; AGM19], and the journal The American 
Statistician (produced by the American Statistical Association) published a whole special issue 
warning about the use of p-values and NHST [WSL19]. 

Traditionally, computation has been a barrier to using Bayesian methods, but this is less of an issue 
these days, due to faster computers and better algorithms (which we will discuss in the sequel to this 
book, [Mur23]). Another, more fundamental, concern is that the Bayesian approach is only as correct 
as its modeling assumptions. However, this criticism also applies to frequentist methods, since the 
sampling distribution of an estimator must be derived using assumptions about the data generating 
mechanism. (In fact [BT73] show that the sampling distributions for the MLE for common models 
are identical to the posterior distributions under a noninformative prior.) Fortunately, we can check 
modeling assumptions empirically using cross validation (Section 4.5.5), calibration, and Bayesian 
model checking. We discuss these topics in the sequel to this book, [Mur23]. 

To summarize, it is worth quoting Donald Rubin, who wrote a paper [Rub84] called “Bayesianly 
Justifiable and Relevant Frequency Calculations for the Applied Statistician”. In it, he writes 
The applied statistician should be Bayesian in principle and calibrated to the real world in
practice. [They] should attempt to use specifications that lead to approximately calibrated
procedures under reasonable deviations from [their assumptions]. [They] should avoid models that
are contradicted by observed data in relevant ways — frequency calculations for hypothetical 
replications can model a model’s adequacy and help to suggest more appropriate models. 


5.6 Exercises 


Exercise 5.1 [Reject option in classifiers] 


(Source: [DHS01, Q2.13].) In many classification problems one has the option either of assigning $x$ to class $j$
or, if you are too uncertain, of choosing the reject option. If the cost for rejects is less than the cost of
falsely classifying the object, it may be the optimal action. Let $a_i$ mean you choose action $i$, for $i = 1 : C+1$,
where $C$ is the number of classes and $C + 1$ is the reject action. Let $Y = j$ be the true (but unknown) state
of nature. Define the loss function as follows

$$\lambda(a_i | Y = j) = \begin{cases} 0 & \text{if } i = j \text{ and } i, j \in \{1, \ldots, C\} \\ \lambda_r & \text{if } i = C + 1 \\ \lambda_s & \text{otherwise} \end{cases} \tag{5.111}$$

In other words, you incur 0 loss if you correctly classify, you incur $\lambda_r$ loss (cost) if you choose the reject
option, and you incur $\lambda_s$ loss (cost) if you make a substitution error (misclassification).


a. Show that the minimum risk is obtained if we decide $Y = j$ if $p(Y = j|x) \geq p(Y = k|x)$ for all $k$ (i.e., $j$ is
the most probable class) and if $p(Y = j|x) \geq 1 - \frac{\lambda_r}{\lambda_s}$; otherwise we decide to reject.


b. Describe qualitatively what happens as $\lambda_r/\lambda_s$ is increased from 0 to 1 (i.e., the relative cost of rejection
increases).


Exercise 5.2 [Newsvendor problem *] 

Consider the following classic problem in decision theory / economics. Suppose you are trying to decide how 
much quantity Q of some product (e.g., newspapers) to buy to maximize your profits. The optimal amount 
will depend on how much demand D you think there is for your product, as well as its cost to you C and its 
selling price P. Suppose D is unknown but has pdf f(D) and cdf F(D). We can evaluate the expected profit 
by considering two cases: if $D > Q$, then we sell all $Q$ items, and make profit $\pi = (P - C)Q$; but if $D < Q$,
we only sell $D$ items, at profit $(P - C)D$, but have wasted $C(Q - D)$ on the unsold items. So the expected
profit if we buy quantity $Q$ is

$$\mathbb{E}\pi(Q) = \int_Q^{\infty} (P - C)\, Q\, f(D)\, dD + \int_0^Q (P - C)\, D\, f(D)\, dD - \int_0^Q C (Q - D)\, f(D)\, dD \tag{5.112}$$

Simplify this expression, and then take derivatives wrt $Q$ to show that the optimal quantity $Q^*$ (which
maximizes the expected profit) satisfies

$$F(Q^*) = \frac{P - C}{P} \tag{5.113}$$

Exercise 5.3 [Bayes factors and ROC curves *] 


Let $B = p(D|H_1)/p(D|H_0)$ be the Bayes factor in favor of model 1. Suppose we plot two ROC curves, one
computed by thresholding $B$, and the other computed by thresholding $p(H_1|D)$. Will they be the same or
different? Explain why.


Exercise 5.4 [Posterior median is optimal estimate under L1 loss] 


Prove that the posterior median is the optimal estimate under L1 loss. 
6 Information Theory


In this chapter, we introduce a few basic concepts from the field of information theory. More
details can be found in other books such as [Mac03; CT06], as well as the sequel to this book, [Mur23].


6.1 Entropy 


The entropy of a probability distribution can be interpreted as a measure of uncertainty, or lack 
of predictability, associated with a random variable drawn from a given distribution, as we explain 
below. 

We can also use entropy to define the information content of a data source. For example,
suppose we observe a sequence of symbols $X_n \sim p$ generated from distribution $p$. If $p$ has high
entropy, it will be hard to predict the value of each observation $X_n$. Hence we say that the dataset
$\mathcal{D} = (X_1, \ldots, X_N)$ has high information content. By contrast, if $p$ is a degenerate distribution with 0
entropy (the minimal value), then every $X_n$ will be the same, so $\mathcal{D}$ does not contain much information.
(All of this can be formalized in terms of data compression, as we discuss in the sequel to this book.)


6.1.1 Entropy for discrete random variables 
The entropy of a discrete random variable $X$ with distribution $p$ over $K$ states is defined by

$$\mathbb{H}(X) \triangleq -\sum_{k=1}^{K} p(X = k) \log_2 p(X = k) = -\mathbb{E}_X[\log p(X)] \tag{6.1}$$

(Note that we use the notation $\mathbb{H}(X)$ to denote the entropy of the rv with distribution $p$, just as
people write $\mathbb{V}[X]$ to mean the variance of the distribution associated with $X$; we could alternatively
write $\mathbb{H}(p)$.) Usually we use log base 2, in which case the units are called bits (short for binary
digits). For example, if $X \in \{1, \ldots, 5\}$ with histogram distribution $p = [0.25, 0.25, 0.2, 0.15, 0.15]$, we
find $H = 2.29$ bits. If we use log base $e$, the units are called nats.
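
As a quick check of the worked example, here is a minimal sketch of Equation (6.1) in Python. The helper name `entropy` is an assumption made for illustration; states with zero probability are skipped, following the usual convention that $0 \log 0 = 0$.

```python
import math


def entropy(probs, base=2.0):
    """Entropy of a discrete distribution; zero-probability states contribute 0."""
    return -sum(p * math.log(p, base) for p in probs if p > 0)


p = [0.25, 0.25, 0.2, 0.15, 0.15]
h_bits = entropy(p)              # log base 2 -> bits
h_nats = entropy(p, math.e)      # natural log -> nats

print(f"H = {h_bits:.2f} bits = {h_nats:.2f} nats")
```

The two values differ only by the constant factor $\ln 2$, since changing the base of the logarithm rescales the entropy.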

The discrete distribution with maximum entropy is the uniform distribution. Hence for a $K$-ary
random variable, the entropy is maximized if $p(x = k) = 1/K$; in this case, $\mathbb{H}(X) = \log_2 K$. To see
this, note that

$$\mathbb{H}(X) = -\sum_{k=1}^{K} \frac{1}{K} \log(1/K) = -\log(1/K) = \log(K) \tag{6.2}$$

Figure 6.1: Entropy of a Bernoulli random variable as a function of $\theta$. The maximum entropy is $\log_2 2 = 1$.
Generated by bernoulli_entropy_fig.ipynb.


Figure 6.2: (a) Some aligned DNA sequences. Each row is a sequence, each column is a location within
the sequence. (b) The corresponding position weight matrix, visualized as a sequence of histograms. Each
column represents a probability distribution over the alphabet {A,C,G,T} for the corresponding location in
the sequence. The size of the letter is proportional to the probability. (c) A sequence logo. See text for details.
Generated by seq_logo_demo.ipynb.
Conversely, the distribution with minimum entropy (which is zero) is any delta-function that puts all
its mass on one state. Such a distribution has no uncertainty.

For the special case of binary random variables, $X \in \{0, 1\}$, we can write $p(X = 1) = \theta$ and
$p(X = 0) = 1 - \theta$. Hence the entropy becomes

$$\begin{aligned}
\mathbb{H}(X) &= -[p(X = 1) \log_2 p(X = 1) + p(X = 0) \log_2 p(X = 0)] && \text{(6.3)} \\
&= -[\theta \log_2 \theta + (1 - \theta) \log_2 (1 - \theta)] && \text{(6.4)}
\end{aligned}$$

This is called the binary entropy function, and is also written $\mathbb{H}(\theta)$. We plot this in Figure 6.1.
We see that the maximum value of 1 bit occurs when the distribution is uniform, $\theta = 0.5$. A fair coin
requires a single yes/no question to determine its state.


6.1.1.1 Application: DNA sequence logos 


As an interesting application of entropy, consider the problem of representing a DNA sequence
motif, which is a distribution over short DNA strings. We can estimate this distribution by aligning
a set of $N$ DNA sequences (e.g., from different species), and then estimating the empirical distribution
of each possible nucleotide from the 4 letter alphabet $X \in \{A, C, G, T\}$ at each location $t$, where $X_{it}$
denotes the letter at location $t$ in the $i$th sequence, as follows:

$$\mathbf{N}_t = \left( \sum_{i=1}^{N} \mathbb{I}(X_{it} = A),\; \sum_{i=1}^{N} \mathbb{I}(X_{it} = C),\; \sum_{i=1}^{N} \mathbb{I}(X_{it} = G),\; \sum_{i=1}^{N} \mathbb{I}(X_{it} = T) \right) \tag{6.5}$$

$$\hat{\boldsymbol{\theta}}_t = \mathbf{N}_t / N \tag{6.6}$$

Here $\mathbf{N}_t$ is a length four vector counting the number of times each letter appears at location $t$
amongst the set of sequences. The collection of distributions $\hat{\boldsymbol{\theta}}_t$ is known as a position weight matrix or a
sequence motif. We can visualize this as shown in Figure 6.2b. Here we plot the letters A, C, G
and T, where the size of letter $k$ at location $t$ is proportional to the empirical frequency $\hat{\theta}_{tk}$.

An alternative visualization, known as a sequence logo, is shown in Figure 6.2c. Each column
is scaled by $R_t = 2 - H_t$, where $H_t$ is the entropy of $\hat{\boldsymbol{\theta}}_t$, and $2 = \log_2(4)$ is the maximum possible
entropy for a distribution over 4 letters. Thus a deterministic distribution, which has entropy 0 and
thus maximal information content, has height 2. Such informative locations are highly conserved
by evolution, often because they are part of a gene coding region. We can also just compute the
most probable letter in each location, regardless of the uncertainty; this is called the consensus
sequence.
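
The quantities above (position weight matrix, column heights $R_t$, and consensus sequence) can be sketched in a few lines of Python. The aligned sequences below are made up for illustration, and the helper names are assumptions, not taken from the notebook that generated Figure 6.2.

```python
import math
from collections import Counter

ALPHABET = "ACGT"


def position_weight_matrix(sequences):
    """One empirical distribution over ALPHABET per column (Equations 6.5-6.6)."""
    n = len(sequences)
    length = len(sequences[0])
    pwm = []
    for t in range(length):
        counts = Counter(seq[t] for seq in sequences)      # the count vector N_t
        pwm.append([counts[letter] / n for letter in ALPHABET])
    return pwm


def column_information(theta):
    """R_t = 2 - H_t in bits: the column height in a sequence logo."""
    h = -sum(p * math.log2(p) for p in theta if p > 0)
    return 2.0 - h


# Hypothetical aligned sequences (illustration only).
seqs = ["ATGC", "ATGA", "ACGT", "ATGG"]

pwm = position_weight_matrix(seqs)
heights = [column_information(theta) for theta in pwm]
# Most probable letter per column; ties are broken by alphabet order.
consensus = "".join(ALPHABET[max(range(4), key=lambda k: theta[k])] for theta in pwm)

for t, (theta, r) in enumerate(zip(pwm, heights)):
    print(t, theta, f"R_t = {r:.2f} bits")
print("consensus:", consensus)
```

In this toy example, the fully conserved columns get the maximal height of 2 bits, while the last column, where all four letters are equally frequent, gets height 0.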


6.1.1.2 Estimating entropy 


Estimating the entropy of a random variable with many possible states requires estimating its 
distribution, which can require a lot of data. For example, imagine if X represents the identity of 
a word in an English document. Since there is a long tail of rare words, and since new words are 
invented all the time, it can be difficult to reliably estimate $p(X)$ and hence $\mathbb{H}(X)$. For one possible
solution to this problem, see [VV13].


6.1.2 Cross entropy 


The cross entropy between distributions $p$ and $q$ is defined by

$$\mathbb{H}_{ce}(p, q) \triangleq -\sum_{k=1}^{K} p_k \log q_k \tag{6.7}$$

One can show that the cross entropy is the expected number of bits needed to compress some data
samples drawn from distribution $p$ using a code based on distribution $q$. This can be minimized by
setting $q = p$, in which case the expected number of bits of the optimal code is $\mathbb{H}_{ce}(p, p) = \mathbb{H}(p)$;
this is known as Shannon’s source coding theorem (see e.g., [CT06]).
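
To make the coding interpretation concrete, the sketch below evaluates Equation (6.7) for a made-up source distribution `p` and a mismatched coding distribution `q` (both hypothetical), and confirms that using `q = p` gives the entropy.

```python
import math


def cross_entropy(p, q):
    """H_ce(p, q) in bits. Assumes q[k] > 0 wherever p[k] > 0."""
    return -sum(pk * math.log2(qk) for pk, qk in zip(p, q) if pk > 0)


p = [0.5, 0.25, 0.25]        # hypothetical source distribution
q = [1 / 3, 1 / 3, 1 / 3]    # hypothetical (mismatched) coding distribution

h_p = cross_entropy(p, p)    # equals the entropy H(p)
h_pq = cross_entropy(p, q)   # cost of coding data from p with a code built for q

print(f"H(p) = {h_p:.3f} bits, H_ce(p, q) = {h_pq:.3f} bits")
```

Here the mismatched code costs $\log_2 3 \approx 1.585$ bits per symbol on average, compared to 1.5 bits for the code matched to `p`.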


6.1.3 Joint entropy 


The joint entropy of two random variables $X$ and $Y$ is defined as

$$\mathbb{H}(X, Y) = -\sum_{x, y} p(x, y) \log_2 p(x, y) \tag{6.8}$$

For example, consider choosing an integer from 1 to 8, $n \in \{1, \ldots, 8\}$. Let $X(n) = 1$ if $n$ is even, and
$Y(n) = 1$ if $n$ is prime:

n | 1 2 3 4 5 6 7 8
X | 0 1 0 1 0 1 0 1
Y | 0 1 1 0 1 0 1 0

The joint distribution is

p(X,Y) | Y=0  Y=1
X=0    | 1/8  3/8
X=1    | 3/8  1/8

so the joint entropy is given by

$$\mathbb{H}(X, Y) = -\left[ \frac{1}{8} \log_2 \frac{1}{8} + \frac{3}{8} \log_2 \frac{3}{8} + \frac{3}{8} \log_2 \frac{3}{8} + \frac{1}{8} \log_2 \frac{1}{8} \right] = 1.81 \text{ bits} \tag{6.9}$$


Clearly the marginal probabilities are uniform: $p(X = 1) = p(X = 0) = p(Y = 0) = p(Y = 1) = 0.5$,
so $\mathbb{H}(X) = \mathbb{H}(Y) = 1$. Hence $\mathbb{H}(X, Y) = 1.81$ bits $< \mathbb{H}(X) + \mathbb{H}(Y) = 2$ bits. In
fact, this upper bound on the joint entropy holds in general. If $X$ and $Y$ are independent, then
$\mathbb{H}(X, Y) = \mathbb{H}(X) + \mathbb{H}(Y)$, so the bound is tight. This makes intuitive sense: when the parts are
correlated in some way, it reduces the “degrees of freedom” of the system, and hence reduces the
overall entropy.

What is the lower bound on $\mathbb{H}(X, Y)$? If $Y$ is a deterministic function of $X$, then $\mathbb{H}(X, Y) = \mathbb{H}(X)$.
So


$$\mathbb{H}(X, Y) \geq \max\{\mathbb{H}(X), \mathbb{H}(Y)\} \geq 0 \tag{6.10}$$


Intuitively this says combining variables together does not make the entropy go down: you cannot 
reduce uncertainty merely by adding more unknowns to the problem, you need to observe some data, 
a topic we discuss in Section 6.1.4. 
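
The even/prime example can be reproduced by enumerating $n = 1, \ldots, 8$ and counting. This is a small sketch (helper names are assumptions) that also checks both bounds on the joint entropy.

```python
import math
from collections import Counter


def entropy_bits(probs):
    return -sum(p * math.log2(p) for p in probs if p > 0)


PRIMES = {2, 3, 5, 7}
# (X, Y) for n = 1..8, where X = 1 if n is even and Y = 1 if n is prime.
pairs = [(int(n % 2 == 0), int(n in PRIMES)) for n in range(1, 9)]

joint = {xy: count / 8 for xy, count in Counter(pairs).items()}   # p(x, y)
p_x = [sum(p for (x, _), p in joint.items() if x == v) for v in (0, 1)]
p_y = [sum(p for (_, y), p in joint.items() if y == v) for v in (0, 1)]

h_xy = entropy_bits(joint.values())
h_x, h_y = entropy_bits(p_x), entropy_bits(p_y)

print(f"H(X,Y) = {h_xy:.2f}, H(X) = {h_x:.2f}, H(Y) = {h_y:.2f}")
```

The printed joint entropy lies between $\max\{\mathbb{H}(X), \mathbb{H}(Y)\} = 1$ bit and $\mathbb{H}(X) + \mathbb{H}(Y) = 2$ bits, as the two bounds require.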

We can extend the definition of joint entropy from two variables to n in the obvious way. 


6.1.4 Conditional entropy 


The conditional entropy of Y given X is the uncertainty we have in Y after seeing X, averaged 
over possible values for X: 


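$$\begin{aligned}
\mathbb{H}(Y|X) &\triangleq \mathbb{E}_{p(X)}[\mathbb{H}(p(Y|X))] && \text{(6.11)} \\
&= \sum_x p(x)\, \mathbb{H}(p(Y|X = x)) = -\sum_x p(x) \sum_y p(y|x) \log p(y|x) && \text{(6.12)} \\
&= -\sum_{x, y} p(x, y) \log p(y|x) = -\sum_{x, y} p(x, y) \log \frac{p(x, y)}{p(x)} && \text{(6.13)} \\
&= -\sum_{x, y} p(x, y) \log p(x, y) + \sum_x p(x) \log p(x) && \text{(6.14)} \\
&= \mathbb{H}(X, Y) - \mathbb{H}(X) && \text{(6.15)}
\end{aligned}$$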


If Y is a deterministic function of X, then knowing X completely determines Y, so H(Y|X) = 0. 
If X and Y are independent, knowing X tells us nothing about Y and H(Y|X) = H(Y). Since 




$\mathbb{H}(X, Y) \leq \mathbb{H}(Y) + \mathbb{H}(X)$, we have

$$\mathbb{H}(Y|X) \leq \mathbb{H}(Y) \tag{6.16}$$


with equality iff X and Y are independent. This shows that, on average, conditioning on data never 
increases one’s uncertainty. The caveat “on average” is necessary because for any particular observation 
(value of X), one may get more “confused” (i.e., H(Y|x) > H(Y)). However, in expectation, looking 
at the data is a good thing to do. (See also Section 6.3.8.) 

We can rewrite Equation (6.15) as follows: 


H (X1, X2) = H (X1) + H (X2| X1) (6.17) 


This can be generalized to get the chain rule for entropy: 


$$\mathbb{H}(X_1, X_2, \ldots, X_n) = \sum_{i=1}^{n} \mathbb{H}(X_i | X_1, \ldots, X_{i-1}) \tag{6.18}$$
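
The “on average” caveat is easy to see numerically. The sketch below uses a made-up joint distribution (an assumption chosen for illustration, not from the text) in which observing $X = 1$ leaves us more uncertain about $Y$ than we were a priori, even though the average conditional entropy is still below $\mathbb{H}(Y)$. It also checks the two-variable chain rule.

```python
import math


def entropy_bits(probs):
    return -sum(p * math.log2(p) for p in probs if p > 0)


# Hypothetical joint p(x, y): rows index x, columns index y.
joint = [[0.50, 0.00],
         [0.25, 0.25]]

p_x = [sum(row) for row in joint]
p_y = [sum(col) for col in zip(*joint)]

# H(Y | X = x) for each value x, then the average over p(x), as in Equation (6.12).
h_y_given_each_x = [
    entropy_bits([pxy / px for pxy in row]) for row, px in zip(joint, p_x)
]
h_y_given_x = sum(px * h for px, h in zip(p_x, h_y_given_each_x))

h_x = entropy_bits(p_x)
h_y = entropy_bits(p_y)
h_xy = entropy_bits([pxy for row in joint for pxy in row])

print(f"H(Y) = {h_y:.3f}")
print(f"H(Y|X=0) = {h_y_given_each_x[0]:.3f}, H(Y|X=1) = {h_y_given_each_x[1]:.3f}")
print(f"H(Y|X) = {h_y_given_x:.3f}, H(X,Y) - H(X) = {h_xy - h_x:.3f}")
```

For this joint distribution, seeing $X = 1$ makes $Y$ a fair coin (1 bit of uncertainty, more than $\mathbb{H}(Y)$), whereas seeing $X = 0$ removes all uncertainty; the average over $p(x)$ is what Equation (6.16) bounds.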


6.1.5 Perplexity 
The perplexity of a discrete probability distribution p is defined as 


$$\text{perplexity}(p) \triangleq 2^{\mathbb{H}(p)} \tag{6.19}$$


This is often interpreted as a measure of predictability. For example, suppose $p$ is a uniform
distribution over $K$ states. In this case, the perplexity is $K$. Obviously the lower bound on perplexity
is $2^0 = 1$, which will be achieved if the distribution can perfectly predict outcomes.

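Now suppose we have an empirical distribution based on data $\mathcal{D}$:

$$p_{\mathcal{D}}(x|\mathcal{D}) = \frac{1}{N} \sum_{n=1}^{N} \delta(x - x_n) \tag{6.20}$$

We can measure how well $p$ predicts $\mathcal{D}$ by computing

$$\text{perplexity}(p_{\mathcal{D}}, p) \triangleq 2^{\mathbb{H}_{ce}(p_{\mathcal{D}}, p)} \tag{6.21}$$

Perplexity is often used to evaluate the quality of statistical language models, which are generative
models for sequences of tokens. Suppose the data is a single long document $\mathbf{x}$ of length $N$, and
suppose $p$ is a simple unigram model. In this case, the cross entropy term is given by

$$H = -\frac{1}{N} \sum_{n=1}^{N} \log p(x_n) \tag{6.22}$$

and hence the perplexity is given by

$$\text{perplexity}(p_{\mathcal{D}}, p) = 2^{H} = 2^{-\frac{1}{N} \log\left(\prod_{n=1}^{N} p(x_n)\right)} = \sqrt[N]{\prod_{n=1}^{N} \frac{1}{p(x_n)}} \tag{6.23}$$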
This is sometimes called the exponentiated cross entropy. We see that this is the geometric
mean of the inverse predictive probabilities.

In the case of language models, we usually condition on previous words when predicting the next
word. For example, in a bigram model, we use a first order Markov model of the form $p(x_i|x_{i-1})$.
We define the branching factor of a language model as the number of possible words that can
follow any given word. We can thus interpret the perplexity as the weighted average branching factor.
For example, suppose the model predicts that each word is equally likely, regardless of context, so
$p(x_i|x_{i-1}) = 1/K$. Then the perplexity is $((1/K)^N)^{-1/N} = K$. If some symbols are more likely than
others, and the model correctly reflects this, its perplexity will be lower than $K$. However, as we
show in Section 6.2, we have $\mathbb{H}(p^*) \leq \mathbb{H}_{ce}(p^*, p)$, so we can never reduce the perplexity below
$2^{\mathbb{H}(p^*)}$, the value implied by the entropy of the underlying stochastic process $p^*$.

See [JM08, p96] for further discussion of perplexity and its uses in language models. 
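
The equivalence between the exponentiated cross entropy and the geometric mean of inverse probabilities can be checked directly. In this sketch, the list `probs` stands for the probabilities a hypothetical model assigned to the tokens that actually occurred; the numbers are made up.

```python
import math


def perplexity(token_probs):
    """2 ** (average negative log2 probability) of the observed tokens."""
    n = len(token_probs)
    cross_entropy_bits = -sum(math.log2(p) for p in token_probs) / n
    return 2.0 ** cross_entropy_bits


# A model that spreads its mass uniformly over K words has perplexity K.
K = 50
uniform_ppl = perplexity([1 / K] * 20)

# Hypothetical probabilities assigned by a model to five observed tokens.
probs = [0.5, 0.25, 0.125, 0.5, 0.25]
ppl = perplexity(probs)
geo_mean_inverse = math.prod(1 / p for p in probs) ** (1 / len(probs))

print(f"uniform: {uniform_ppl:.1f}, model: {ppl:.3f}, geometric mean: {geo_mean_inverse:.3f}")
```

Note that this uses base-2 logarithms in the cross entropy, so that it matches the base of the exponentiation.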


6.1.6 Differential entropy for continuous random variables * 


If $X$ is a continuous random variable with pdf $p(x)$, we define the differential entropy as

$$h(X) \triangleq -\int_{\mathcal{X}} p(x) \log p(x)\, dx \tag{6.24}$$

assuming this integral exists. For example, suppose $X \sim U(0, a)$. Then

$$h(X) = -\int_0^a dx\, \frac{1}{a} \log \frac{1}{a} = \log a \tag{6.25}$$

Note that, unlike the discrete case, differential entropy can be negative. This is because pdf’s can be
bigger than 1. For example if $X \sim U(0, 1/8)$, we have $h(X) = \log_2(1/8) = -3$.

One way to understand differential entropy is to realize that all real-valued quantities can only be
represented to finite precision. It can be shown [CT91, p228] that the entropy of an $n$-bit quantization
of a continuous random variable $X$ is approximately $h(X) + n$. For example, suppose $X \sim U(0, \frac{1}{8})$.
Then in a binary representation of $X$, the first 3 bits to the right of the binary point must be 0 (since
the number is $\leq 1/8$). So to describe $X$ to $n$ bits of accuracy only requires $n - 3$ bits, which agrees
with $h(X) = -3$ calculated above.


6.1.6.1 Example: Entropy of a Gaussian 


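The entropy of a $d$-dimensional Gaussian is

$$h(\mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})) = \frac{1}{2} \ln |2\pi e \boldsymbol{\Sigma}| = \frac{1}{2} \ln[(2\pi e)^d |\boldsymbol{\Sigma}|] = \frac{d}{2} + \frac{d}{2} \ln(2\pi) + \frac{1}{2} \ln |\boldsymbol{\Sigma}| \tag{6.26}$$

In the 1d case, this becomes

$$h(\mathcal{N}(\mu, \sigma^2)) = \frac{1}{2} \ln\left[2\pi e \sigma^2\right] \tag{6.27}$$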
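
Equation (6.26) is straightforward to evaluate numerically. The following sketch (the function name `gaussian_entropy` is an assumption) uses NumPy's `slogdet` to compute $\ln |\boldsymbol{\Sigma}|$ stably, and checks that the 1d case reduces to Equation (6.27).

```python
import numpy as np


def gaussian_entropy(cov):
    """Differential entropy (in nats) of N(mu, cov); the mean plays no role."""
    cov = np.atleast_2d(np.asarray(cov, dtype=float))
    d = cov.shape[0]
    _, logdet = np.linalg.slogdet(cov)       # ln|cov|, computed stably
    return 0.5 * d * (1.0 + np.log(2.0 * np.pi)) + 0.5 * logdet


sigma2 = 4.0
h_1d = gaussian_entropy(sigma2)                     # scalar variance
h_1d_formula = 0.5 * np.log(2 * np.pi * np.e * sigma2)

# For a diagonal covariance the dimensions are independent, so entropies add.
h_2d = gaussian_entropy(np.diag([1.0, sigma2]))

print(f"h_1d = {h_1d:.4f} nats, h_2d = {h_2d:.4f} nats")
```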
6.1.6.2 Connection with variance


The entropy of a Gaussian increases monotonically as the variance increases. However, this is not 
always the case. For example, consider a mixture of two 1d Gaussians centered at -1 and +1. As we 
move the means further apart, say to -10 and +10, the variance increases (since the average distance 
from the overall mean gets larger). However, the entropy remains more or less the same, since we are 
still uncertain about where a sample might fall, even if we know that it will be near -10 or +10. (The 
exact entropy of a GMM is hard to compute, but a method to compute upper and lower bounds is 
presented in [Hub+08].) 
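
The mixture example can be explored numerically in 1d. The sketch below approximates the variance and differential entropy of an equal-weight mixture $0.5\,\mathcal{N}(-m, 1) + 0.5\,\mathcal{N}(+m, 1)$ with a Riemann sum on a fine grid; the grid width, resolution, and unit component variance are assumptions chosen for illustration.

```python
import numpy as np


def mixture_stats(m, sigma=1.0, num_points=200_001):
    """Variance and differential entropy (nats) of 0.5 N(-m, s^2) + 0.5 N(+m, s^2)."""
    x = np.linspace(-m - 12 * sigma, m + 12 * sigma, num_points)
    dx = x[1] - x[0]

    def log_normal(mean):
        return -0.5 * ((x - mean) / sigma) ** 2 - 0.5 * np.log(2 * np.pi * sigma**2)

    log_p = np.logaddexp(log_normal(-m), log_normal(m)) + np.log(0.5)
    p = np.exp(log_p)
    variance = np.sum(p * x**2) * dx      # the overall mean is 0 by symmetry
    entropy = -np.sum(p * log_p) * dx     # Riemann-sum approximation of (6.24)
    return variance, entropy


for m in (1.0, 10.0, 100.0):
    var, h = mixture_stats(m)
    print(f"means at -/+{m:g}: variance = {var:.1f}, entropy = {h:.3f} nats")
```

Once the two components no longer overlap, the entropy settles at the entropy of a single component plus $\ln 2$ (the cost of saying which component a sample came from), while the variance keeps growing as $1 + m^2$.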


6.1.6.3 Discretization 


In general, computing the differential entropy for a continuous random variable can be difficult. A 
simple approximation is to discretize or quantize the variables. There are various methods for this 
(see e.g., [DKS95; KK06] for a summary), but a simple approach is to bin the distribution based on 
its empirical quantiles. The critical question is how many bins to use [LM04]. Scott [Sco79] suggested 
the following heuristic: 


$$B = N^{1/3}\, \frac{\max(\mathcal{D}) - \min(\mathcal{D})}{3.5\, \sigma(\mathcal{D})} \tag{6.28}$$

where $\sigma(\mathcal{D})$ is the empirical standard deviation of the data, and $N = |\mathcal{D}|$ is the number of datapoints
in the empirical distribution. However, the technique of discretization does not scale well if $X$ is a
multi-dimensional random vector, due to the curse of dimensionality.
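
As a small illustration of Equation (6.28), the sketch below computes the suggested number of bins for a made-up dataset. Rounding up to an integer and using the population standard deviation for $\sigma(\mathcal{D})$ are assumptions; the text does not specify either choice.

```python
import math
import statistics


def scott_num_bins(data):
    """Number of bins B from Equation (6.28), rounded up to an integer (an assumption)."""
    n = len(data)
    spread = max(data) - min(data)
    sigma = statistics.pstdev(data)      # empirical (population) standard deviation
    return math.ceil(n ** (1 / 3) * spread / (3.5 * sigma))


# Hypothetical data: 1000 evenly spaced points on [0, 99.9].
data = [0.1 * i for i in range(1000)]
num_bins = scott_num_bins(data)
print("suggested number of bins:", num_bins)
```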


6.2 Relative entropy (KL divergence) * 


Given two distributions $p$ and $q$, it is often useful to define a distance metric to measure how “close”
or “similar” they are. In fact, we will be more general and consider a divergence measure $D(p, q)$
which quantifies how far $q$ is from $p$, without requiring that $D$ be a metric. More precisely, we say
that $D$ is a divergence if $D(p, q) \geq 0$ with equality iff $p = q$, whereas a metric also requires that $D$ be
symmetric and satisfy the triangle inequality, $D(p, r) \leq D(p, q) + D(q, r)$. There are many possible
divergence measures we can use. In this section, we focus on the Kullback-Leibler divergence 
or KL divergence, also known as the information gain or relative entropy, between two 
distributions p and q. 


6.2.1 Definition 


For discrete distributions, the KL divergence is defined as follows:

$$D_{\mathbb{KL}}(p \,\|\, q) \triangleq \sum_{k=1}^{K} p_k \log \frac{p_k}{q_k} \tag{6.29}$$

This naturally extends to continuous distributions as well:

$$D_{\mathbb{KL}}(p \,\|\, q) \triangleq \int dx\; p(x) \log \frac{p(x)}{q(x)} \tag{6.30}$$
6.2.2 Interpretation


We can rewrite the KL as follows:

$$D_{\mathbb{KL}}(p \,\|\, q) = \underbrace{\sum_{k=1}^{K} p_k \log p_k}_{-\mathbb{H}(p)} \; \underbrace{- \sum_{k=1}^{K} p_k \log q_k}_{\mathbb{H}_{ce}(p, q)} \tag{6.31}$$

We recognize the first term as the negative entropy, and the second term as the cross entropy. It
can be shown that the cross entropy $\mathbb{H}_{ce}(p, q)$ is a lower bound on the number of bits needed to
compress data coming from distribution $p$ if your code is designed based on distribution $q$; thus we
can interpret the KL divergence as the “extra number of bits” you need to pay when compressing
data samples if you use the incorrect distribution $q$ as the basis of your coding scheme compared to
the true distribution $p$.

There are various other interpretations of KL divergence. See the sequel to this book, [Mur23], for 
more information. 
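
The decomposition of the KL divergence into cross entropy minus entropy can be verified on a small example. The distributions `p` and `q` below are hypothetical; the sketch also shows that the KL divergence is not symmetric in its arguments.

```python
import math


def kl_divergence(p, q):
    """D_KL(p || q) in bits. Assumes q[k] > 0 wherever p[k] > 0."""
    return sum(pk * math.log2(pk / qk) for pk, qk in zip(p, q) if pk > 0)


p = [0.5, 0.25, 0.25]        # hypothetical "true" distribution
q = [1 / 3, 1 / 3, 1 / 3]    # hypothetical distribution used to design the code

entropy_p = -sum(pk * math.log2(pk) for pk in p)
cross_entropy_pq = -sum(pk * math.log2(qk) for pk, qk in zip(p, q))

extra_bits = kl_divergence(p, q)
reverse = kl_divergence(q, p)

print(f"KL(p||q) = {extra_bits:.4f} bits = {cross_entropy_pq:.4f} - {entropy_p:.4f}")
print(f"KL(q||p) = {reverse:.4f} bits")
```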


6.2.3 Example: KL divergence between two Gaussians 


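For example, one can show that the KL divergence between two multivariate Gaussian distributions
is given by

$$D_{\mathbb{KL}}(\mathcal{N}(\mathbf{x}|\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1) \,\|\, \mathcal{N}(\mathbf{x}|\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_2))
= \frac{1}{2} \left[ \mathrm{tr}(\boldsymbol{\Sigma}_2^{-1} \boldsymbol{\Sigma}_1) + (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)^{\mathsf{T}} \boldsymbol{\Sigma}_2^{-1} (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1) - D + \log\left( \frac{\det(\boldsymbol{\Sigma}_2)}{\det(\boldsymbol{\Sigma}_1)} \right) \right] \tag{6.32}$$

In the scalar case, this becomes

$$D_{\mathbb{KL}}(\mathcal{N}(x|\mu_1, \sigma_1) \,\|\, \mathcal{N}(x|\mu_2, \sigma_2)) = \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2 \sigma_2^2} - \frac{1}{2} \tag{6.33}$$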
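
Both formulas can be implemented and cross-checked against each other. This is a sketch under the assumption that the multivariate function receives covariance matrices (variances in 1d), while the scalar function receives standard deviations, matching the notation of Equation (6.33); the function names are made up.

```python
import numpy as np


def kl_gaussians(mu1, cov1, mu2, cov2):
    """D_KL(N(mu1, cov1) || N(mu2, cov2)) in nats, following Equation (6.32)."""
    mu1 = np.atleast_1d(mu1).astype(float)
    mu2 = np.atleast_1d(mu2).astype(float)
    cov1 = np.atleast_2d(cov1).astype(float)
    cov2 = np.atleast_2d(cov2).astype(float)
    d = mu1.shape[0]
    diff = mu2 - mu1
    cov2_inv = np.linalg.inv(cov2)
    _, logdet1 = np.linalg.slogdet(cov1)
    _, logdet2 = np.linalg.slogdet(cov2)
    return 0.5 * (np.trace(cov2_inv @ cov1) + diff @ cov2_inv @ diff - d + logdet2 - logdet1)


def kl_gaussians_1d(mu1, sigma1, mu2, sigma2):
    """Scalar case, Equation (6.33); sigma1 and sigma2 are standard deviations."""
    return np.log(sigma2 / sigma1) + (sigma1**2 + (mu1 - mu2) ** 2) / (2 * sigma2**2) - 0.5


# Same pair of 1d Gaussians, passed as variances and as standard deviations.
kl_general = kl_gaussians(0.0, 1.0, 1.0, 4.0)
kl_scalar = kl_gaussians_1d(0.0, 1.0, 1.0, 2.0)
print(f"general: {kl_general:.4f} nats, scalar: {kl_scalar:.4f} nats")
```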


6.2.4 Non-negativity of KL 


In this section, we prove that the KL divergence is always non-negative.
To do this, we use Jensen’s inequality. This states that, for any convex function $f$, we have
that

$$f\left( \sum_{i=1}^{n} \lambda_i \mathbf{x}_i \right) \leq \sum_{i=1}^{n} \lambda_i f(\mathbf{x}_i) \tag{6.34}$$

where $\lambda_i \geq 0$ and $\sum_{i=1}^{n} \lambda_i = 1$. In words, this result says that $f$ of the average is at most the
average of the $f$’s. This is clearly true for $n = 2$, since a convex function lies on or below the straight
line connecting the two end points (see Section 8.1.3). To prove for general $n$, we can use induction.
For example, if $f(x) = \log(x)$, which is a concave function, the inequality is reversed, and we have

$$\log(\mathbb{E}_x g(x)) \geq \mathbb{E}_x \log(g(x)) \tag{6.35}$$

We use this result below.
Theorem 6.2.1. (Information inequality) $D_{\mathbb{KL}}(p \,\|\, q) \geq 0$ with equality iff $p = q$.


Proof. We now prove the theorem following [CT06, p28]. Let $A = \{x : p(x) > 0\}$ be the support of
$p(x)$. Using the concavity of the log function and Jensen’s inequality (Section 6.2.4), we have that

$$\begin{aligned}
-D_{\mathbb{KL}}(p \,\|\, q) &= -\sum_{x \in A} p(x) \log \frac{p(x)}{q(x)} = \sum_{x \in A} p(x) \log \frac{q(x)}{p(x)} && \text{(6.36)} \\
&\leq \log \sum_{x \in A} p(x) \frac{q(x)}{p(x)} = \log \sum_{x \in A} q(x) && \text{(6.37)} \\
&\leq \log \sum_{x \in \mathcal{X}} q(x) = \log 1 = 0 && \text{(6.38)}
\end{aligned}$$

Since $\log(x)$ is a strictly concave function ($-\log(x)$ is convex), we have equality in Equation (6.37) iff
$p(x) = c\,q(x)$ for some $c$ that tracks the fraction of the whole space $\mathcal{X}$ contained in $A$. We have equality
in Equation (6.38) iff $\sum_{x \in A} q(x) = \sum_{x \in \mathcal{X}} q(x) = 1$, which implies $c = 1$. Hence $D_{\mathbb{KL}}(p \,\|\, q) = 0$ iff
$p(x) = q(x)$ for all $x$.


This theorem has many important implications, as we will see throughout the book. For example, 
we can show that the uniform distribution is the one that maximizes the entropy: 


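Corollary 6.2.1. (Uniform distribution maximizes the entropy) $\mathbb{H}(X) \leq \log |\mathcal{X}|$, where $|\mathcal{X}|$ is the
number of states for $X$, with equality iff $p(x)$ is uniform.


Proof. Let $u(x) = 1/|\mathcal{X}|$. Then

$$0 \leq D_{\mathbb{KL}}(p \,\|\, u) = \sum_x p(x) \log \frac{p(x)}{u(x)} = \log |\mathcal{X}| - \mathbb{H}(X) \tag{6.39}$$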
u(x) 


6.2.5 KL divergence and MLE 


Suppose we want to find the distribution q that is as close as possible to p, as measured by KL 
divergence: 


$$q^* = \arg\min_q D_{\mathbb{KL}}(p \,\|\, q) = \arg\min_q \int p(x) \log p(x)\, dx - \int p(x) \log q(x)\, dx \tag{6.40}$$


Now suppose $p$ is the empirical distribution, which puts a probability atom on the observed training
data and zero mass everywhere else:

$$p_{\mathcal{D}}(x) = \frac{1}{N} \sum_{n=1}^{N} \delta(x - x_n) \tag{6.41}$$
Using the sifting property of delta functions we get

$$\begin{aligned}
D_{\mathbb{KL}}(p_{\mathcal{D}} \,\|\, q) &= -\int p_{\mathcal{D}}(x) \log q(x)\, dx + C && \text{(6.42)} \\
&= -\int \left[ \frac{1}{N} \sum_{n=1}^{N} \delta(x - x_n) \right] \log q(x)\, dx + C && \text{(6.43)} \\
&= -\frac{1}{N} \sum_{n=1}^{N} \log q(x_n) + C && \text{(6.44)}
\end{aligned}$$

where $C = \int p_{\mathcal{D}}(x) \log p_{\mathcal{D}}(x)\, dx$ is a constant independent of $q$. This is called the cross entropy
objective, and is equal to the average negative log likelihood of $q$ on the training set. Thus we see
that minimizing KL divergence to the empirical distribution is equivalent to maximizing likelihood.

This perspective points out the flaw with likelihood-based training, namely that it puts too 
much weight on the training set. In most applications, we do not really believe that the empirical 
distribution is a good representation of the true distribution, since it just puts “spikes” on a finite 
set of points, and zero density everywhere else. Even if the dataset is large (say 1M images), the 
universe from which the data is sampled is usually even larger (e.g., the set of “all natural images” is 
much larger than 1M). We could smooth the empirical distribution using kernel density estimation 
(Section 16.3), but that would require a similar kernel on the space of images. An alternative, 
algorithmic approach is to use data augmentation, which is a way of perturbing the observed
data samples in a way that we believe reflects plausible “natural variation”. Applying MLE on this
augmented dataset often yields superior results, especially when fitting models with many parameters 
(see Section 19.1). 


6.2.6 Forward vs reverse KL 


Suppose we want to approximate a distribution $p$ using a simpler distribution $q$. We can do this by
minimizing $D_{\mathbb{KL}}(q \,\|\, p)$ or $D_{\mathbb{KL}}(p \,\|\, q)$. This gives rise to different behavior, as we discuss below.
First we consider the forwards KL, also called the inclusive KL, defined by

$$D_{\mathbb{KL}}(p \,\|\, q) = \int p(x) \log \frac{p(x)}{q(x)}\, dx \tag{6.45}$$

Minimizing this wrt $q$ is known as an M-projection or moment projection.

We can gain an understanding of the optimal q by considering inputs x for which p(x) > 0 but 
q(x) = 0. In this case, the term log p(x)/q(x) will be infinite. Thus minimizing the KL will force q 
to include all the areas of space for which p has non-zero probability. Put another way, q will be 
zero-avoiding or mode-covering, and will typically over-estimate the support of p. Figure 6.3(a) 
illustrates mode covering where p is a bimodal distribution but q is unimodal. 

Now consider the reverse KL, also called the exclusive KL: 


$$D_{\mathbb{KL}}(q \,\|\, p) = \int q(x) \log \frac{q(x)}{p(x)}\, dx \tag{6.46}$$


Minimizing this wrt q is known as an I-projection or information projection. 
Figure 6.3: Illustrating forwards vs reverse KL on a bimodal distribution. The blue curves are the contours of
the true distribution $p$. The red curves are the contours of the unimodal approximation $q$. (a) Minimizing
forwards KL, $D_{\mathbb{KL}}(p \,\|\, q)$, wrt $q$ causes $q$ to “cover” $p$. (b-c) Minimizing reverse KL, $D_{\mathbb{KL}}(q \,\|\, p)$, wrt $q$
causes $q$ to “lock onto” one of the two modes of $p$. Adapted from Figure 10.3 of [Bis06]. Generated by
KLfwdReverseMixGauss.ipynb.


We can gain an understanding of the optimal q by considering inputs x for which p(x) = 0 but
q(x) > 0. In this case, the term log q(x) /p(x) will be infinite. Thus minimizing the exclusive KL will 
force q to exclude all the areas of space for which p has zero probability. One way to do this is for q 
to put probability mass in very few parts of space; this is called zero-forcing or mode-seeking 
behavior. In this case, q will typically under-estimate the support of p. We illustrate mode seeking 
when p is bimodal but q is unimodal in Figure 6.3(b-c). 
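
The mode-covering and mode-seeking behaviors can be reproduced with a brute-force 1d experiment. The sketch below is not the notebook that generated Figure 6.3; it assumes a made-up bimodal target (an equal mixture of $\mathcal{N}(-4, 1)$ and $\mathcal{N}(4, 1)$), approximates both KL divergences by Riemann sums on a grid, and searches over the mean and standard deviation of a single Gaussian $q$.

```python
import numpy as np

x = np.linspace(-14.0, 14.0, 2801)
dx = x[1] - x[0]


def log_normal(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


# Hypothetical bimodal target: equal mixture of N(-4, 1) and N(+4, 1).
log_p = np.logaddexp(log_normal(x, -4.0, 1.0), log_normal(x, 4.0, 1.0)) + np.log(0.5)


def kl_on_grid(log_a, log_b):
    """Riemann-sum approximation of D_KL(a || b) from log densities on the grid."""
    return np.sum(np.exp(log_a) * (log_a - log_b)) * dx


best_fwd, best_rev = None, None
for mu in np.linspace(-6.0, 6.0, 49):           # candidate means, step 0.25
    for sigma in np.linspace(0.5, 5.0, 46):     # candidate std devs, step 0.1
        log_q = log_normal(x, mu, sigma)
        fwd = kl_on_grid(log_p, log_q)          # forwards (inclusive) KL(p || q)
        rev = kl_on_grid(log_q, log_p)          # reverse (exclusive) KL(q || p)
        if best_fwd is None or fwd < best_fwd[0]:
            best_fwd = (fwd, mu, sigma)
        if best_rev is None or rev < best_rev[0]:
            best_rev = (rev, mu, sigma)

print("forwards KL minimizer: mu = %.2f, sigma = %.2f" % best_fwd[1:])
print("reverse  KL minimizer: mu = %.2f, sigma = %.2f" % best_rev[1:])
```

The forwards KL solution sits between the modes with a large standard deviation (it matches the mean and variance of $p$), whereas the reverse KL solution locks onto one of the two modes with a standard deviation close to that of a single component. Which mode is selected is arbitrary, since the target is symmetric.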


6.3 Mutual information * 


The KL divergence gave us a way to measure how similar two distributions are. How should we
measure how dependent two random variables are? One thing we could do is turn the question
of measuring the dependence of two random variables into a question about the similarity of their 
distributions. This gives rise to the notion of mutual information (MI) between two random 
variables, which we define below. 


6.3.1 Definition 


The mutual information between rv’s X and Y is defined as follows: 


1(X;¥) 4 Dux (p(2,9) lawpa = = E rey) 10s ee (6.47) 


(We write I(X;Y) instead of I(X,Y), in case X and/or Y represent sets of variables; for example, we 
can write I(X;Y, Z) to represent the MI between X and (Y, Z).) For continuous random variables, 
we just replace sums with integrals. 

It is easy to see that MI is always non-negative, even for continuous random variables, since 


$$\mathbb{I}(X;Y) = D_{\mathbb{KL}}(p(x,y) \parallel p(x)p(y)) \ge 0 \tag{6.48}$$

We achieve the bound of 0 iff p(x, y) = p(x)p(y).
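
For a discrete joint pmf, the definition in Equation (6.47) is a short computation. The sketch below is illustrative only: the function name `mutual_information` and the two example tables are assumptions made here, not taken from the text.

```python
import numpy as np

def mutual_information(pxy, base=2.0):
    """I(X;Y) for a joint pmf given as a 2d array (rows index x, columns index y)."""
    pxy = np.asarray(pxy, dtype=float)
    px = pxy.sum(axis=1, keepdims=True)   # marginal p(x), shape (nx, 1)
    py = pxy.sum(axis=0, keepdims=True)   # marginal p(y), shape (1, ny)
    mask = pxy > 0                        # terms with p(x,y) = 0 contribute 0
    ratio = pxy[mask] / (px @ py)[mask]
    return float(np.sum(pxy[mask] * np.log(ratio)) / np.log(base))

# Hypothetical examples: an independent joint and a dependent one.
independent = np.outer([0.3, 0.7], [0.6, 0.4])
dependent = np.array([[0.4, 0.1],
                      [0.1, 0.4]])

print(mutual_information(independent))  # zero up to floating-point rounding
print(mutual_information(dependent))    # strictly positive
```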


6.3.2 Interpretation 


Knowing that the mutual information is a KL divergence between the joint and factored marginal 
distributions tells us that the MI measures the information gain if we update from a model that 
treats the two variables as independent, p(x)p(y), to one that models their true joint density p(x, y).

To gain further insight into the meaning of MI, it helps to re-express it in terms of joint and 
conditional entropies, as follows: 


$$\mathbb{I}(X;Y) = \mathbb{H}(X) - \mathbb{H}(X|Y) = \mathbb{H}(Y) - \mathbb{H}(Y|X) \tag{6.49}$$


Thus we can interpret the MI between X and Y as the reduction in uncertainty about X after 
observing Y, or, by symmetry, the reduction in uncertainty about Y after observing X. Incidentally, 
this result gives an alternative proof that conditioning, on average, reduces entropy. In particular, we 
have $0 \le \mathbb{I}(X;Y) = \mathbb{H}(X) - \mathbb{H}(X|Y)$, and hence $\mathbb{H}(X|Y) \le \mathbb{H}(X)$.

We can also obtain a different interpretation. One can show that 


$$\mathbb{I}(X;Y) = \mathbb{H}(X,Y) - \mathbb{H}(X|Y) - \mathbb{H}(Y|X) \tag{6.50}$$

Finally, one can show that

$$\mathbb{I}(X;Y) = \mathbb{H}(X) + \mathbb{H}(Y) - \mathbb{H}(X,Y) \tag{6.51}$$


See Figure 6.4 for a summary of these equations in terms of an information diagram. (Formally, 
this is a signed measure mapping set expressions to their information-theoretic counterparts [Yeu91].) 


6.3.3 Example 


As an example, let us reconsider the example concerning prime and even numbers from Section 6.1.3. 
Recall that H(X) =H (Y) = 1. The conditional distribution p(Y|X) is given by normalizing each 
row: 


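     | Y=0   Y=1
X=0  | 1/4   3/4
X=1  | 3/4   1/4

Hence the conditional entropy is

$$\mathbb{H}(Y|X) = -\left[\frac{1}{8}\log_2\frac{1}{4} + \frac{3}{8}\log_2\frac{3}{4} + \frac{3}{8}\log_2\frac{3}{4} + \frac{1}{8}\log_2\frac{1}{4}\right] = 0.81 \text{ bits} \tag{6.52}$$

and the mutual information is

$$\mathbb{I}(X;Y) = \mathbb{H}(Y) - \mathbb{H}(Y|X) = (1 - 0.81) \text{ bits} = 0.19 \text{ bits} \tag{6.53}$$

You can easily verify that

\begin{align}
\mathbb{H}(X,Y) &= \mathbb{H}(X|Y) + \mathbb{I}(X;Y) + \mathbb{H}(Y|X) \tag{6.54} \\
&= (0.81 + 0.19 + 0.81) \text{ bits} = 1.81 \text{ bits} \tag{6.55}
\end{align}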
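
These numbers can be reproduced in a few lines. The sketch below assumes the joint distribution implied by the conditional table and the uniform marginal H(X) = 1, namely p(X=0,Y=0) = p(X=1,Y=1) = 1/8 and p(X=0,Y=1) = p(X=1,Y=0) = 3/8; the helper `entropy` is a name introduced here for illustration.

```python
import numpy as np

def entropy(p):
    """Entropy in bits of a pmf given as an array of any shape."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return float(-np.sum(p * np.log2(p)))

# Assumed joint pmf: rows index X in {0, 1}, columns index Y in {0, 1}.
pxy = np.array([[1, 3],
                [3, 1]]) / 8.0

H_X = entropy(pxy.sum(axis=1))
H_Y = entropy(pxy.sum(axis=0))
H_XY = entropy(pxy)
H_Y_given_X = H_XY - H_X          # chain rule for entropy
H_X_given_Y = H_XY - H_Y
I_XY = H_X + H_Y - H_XY           # Equation (6.51)

print(f"H(Y|X) = {H_Y_given_X:.2f} bits")
print(f"I(X;Y) = {I_XY:.2f} bits")
print(f"H(X,Y) = {H_XY:.2f} bits")
```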
[Figure 6.4 panels: entropy H(X) and H(Y); joint entropy H(X,Y), shown as X ∪ Y; mutual information I(X;Y), shown as X ∩ Y; conditional entropy H(X|Y), shown as X − Y, and H(Y|X), shown as Y − X.]


Figure 6.4: The marginal entropy, joint entropy, conditional entropy and mutual information represented as 
information diagrams. Used with kind permission of Katie Everett. 


6.3.4 Conditional mutual information 


We can define the conditional mutual information in the obvious way 


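\begin{align}
\mathbb{I}(X;Y|Z) &\triangleq \mathbb{E}_{p(Z)}\left[\mathbb{I}(X;Y)|Z\right] \tag{6.56} \\
&= \mathbb{E}_{p(x,y,z)}\left[\log\frac{p(x,y|z)}{p(x|z)p(y|z)}\right] \tag{6.57} \\
&= \mathbb{H}(X|Z) + \mathbb{H}(Y|Z) - \mathbb{H}(X,Y|Z) \tag{6.58} \\
&= \mathbb{H}(X|Z) - \mathbb{H}(X|Y,Z) = \mathbb{H}(Y|Z) - \mathbb{H}(Y|X,Z) \tag{6.59} \\
&= \mathbb{H}(X,Z) + \mathbb{H}(Y,Z) - \mathbb{H}(Z) - \mathbb{H}(X,Y,Z) \tag{6.60} \\
&= \mathbb{I}(Y;X,Z) - \mathbb{I}(Y;Z) \tag{6.61}
\end{align}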


The last equation tells us that the conditional MI is the extra (residual) information that X tells us 
about Y, excluding what we already knew about Y given Z alone. 
We can rewrite Equation (6.61) as follows: 


$$\mathbb{I}(Z,Y;X) = \mathbb{I}(Z;X) + \mathbb{I}(Y;X|Z) \tag{6.62}$$

Generalizing to N variables, we get the chain rule for mutual information:

$$\mathbb{I}(Z_1,\ldots,Z_N;X) = \sum_{n=1}^{N}\mathbb{I}(Z_n;X|Z_1,\ldots,Z_{n-1}) \tag{6.63}$$
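
The two-variable form of the chain rule, Equation (6.62), can be verified on a random joint pmf. This is only a numerical sanity check under assumed sizes (3, 4 and 2 states for X, Y and Z); the helper `mi` and the variable names are illustrative.

```python
import numpy as np

def mi(pab):
    """Mutual information (in bits) of a 2d joint pmf with strictly positive entries."""
    pa = pab.sum(axis=1, keepdims=True)
    pb = pab.sum(axis=0, keepdims=True)
    return float(np.sum(pab * np.log2(pab / (pa * pb))))

rng = np.random.default_rng(0)
p = rng.random((3, 4, 2)) + 0.01   # axes: (X, Y, Z); the offset keeps entries positive
p /= p.sum()

pz = p.sum(axis=(0, 1))    # p(z), shape (2,)
pxz = p.sum(axis=1)        # p(x, z), shape (3, 2)
pyz = p.sum(axis=0)        # p(y, z), shape (4, 2)

# I(Y;X|Z) computed directly from Equation (6.57), using
# p(x,y|z) / (p(x|z) p(y|z)) = p(x,y,z) p(z) / (p(x,z) p(y,z)).
cond_mi = float(np.sum(
    p * np.log2(p * pz[None, None, :] / (pxz[:, None, :] * pyz[None, :, :]))
))

# I(Z,Y;X): treat the pair (Y, Z) as one variable with 4 * 2 = 8 states.
lhs = mi(p.reshape(3, -1))
rhs = mi(pxz) + cond_mi    # I(Z;X) + I(Y;X|Z)
print(lhs, rhs)
```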
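6.3.5 MI as a “generalized correlation coefficient”


Suppose that (x, y) are jointly Gaussian:


$$\begin{pmatrix} x \\ y \end{pmatrix} \sim \mathcal{N}\left(\begin{pmatrix} 0 \\ 0 \end{pmatrix}, \begin{pmatrix} \sigma^2 & \rho\sigma^2 \\ \rho\sigma^2 & \sigma^2 \end{pmatrix}\right) \tag{6.64}$$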


We now show how to compute the mutual information between X and Y. 
Using Equation (6.26), we find that the entropy is 


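$$h(X,Y) = \frac{1}{2}\log\left[(2\pi e)^2 \det\boldsymbol{\Sigma}\right] = \frac{1}{2}\log\left[(2\pi e)^2\sigma^4(1-\rho^2)\right] \tag{6.65}$$


Since X and Y are individually normal with variance $\sigma^2$, we have


$$h(X) = h(Y) = \frac{1}{2}\log\left[2\pi e\sigma^2\right] \tag{6.66}$$

Hence

\begin{align}
\mathbb{I}(X,Y) &= h(X) + h(Y) - h(X,Y) \tag{6.67} \\
&= \log\left[2\pi e\sigma^2\right] - \frac{1}{2}\log\left[(2\pi e)^2\sigma^4(1-\rho^2)\right] \tag{6.68} \\
&= \frac{1}{2}\log\left[(2\pi e\sigma^2)^2\right] - \frac{1}{2}\log\left[(2\pi e\sigma^2)^2(1-\rho^2)\right] \tag{6.69} \\
&= \frac{1}{2}\log\frac{1}{1-\rho^2} = -\frac{1}{2}\log\left[1-\rho^2\right] \tag{6.70}
\end{align}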


We now discuss some interesting special cases. 


1. ρ = 1. In this case, X = Y, and I(X,Y) = ∞, which makes sense. Observing Y tells us an infinite
amount of information about X (as we know its real value exactly).


2. ρ = 0. In this case, X and Y are independent, and I(X,Y) = 0, which makes sense. Observing Y
tells us nothing about X.


3. ρ = −1. In this case, X = −Y, and I(X,Y) = ∞, which again makes sense. Observing Y allows
us to predict X to infinite precision.
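
The closed form in Equation (6.70) can be checked against the entropy decomposition in Equation (6.67). A small sketch, with function names chosen here for illustration and natural logarithms (nats) assumed:

```python
import numpy as np

def gaussian_mi(rho):
    """Closed form from Equation (6.70), in nats."""
    return -0.5 * np.log(1.0 - rho ** 2)

def gaussian_mi_from_entropies(rho, sigma=1.0):
    """h(X) + h(Y) - h(X,Y), using Equations (6.65) and (6.66)."""
    Sigma = sigma ** 2 * np.array([[1.0, rho],
                                   [rho, 1.0]])
    h_joint = 0.5 * np.log((2 * np.pi * np.e) ** 2 * np.linalg.det(Sigma))
    h_marginal = 0.5 * np.log(2 * np.pi * np.e * sigma ** 2)
    return 2 * h_marginal - h_joint

for rho in (0.0, 0.5, -0.9, 0.999):
    print(rho, gaussian_mi(rho), gaussian_mi_from_entropies(rho, sigma=2.0))
```

The MI depends on ρ only through ρ², so positive and negative correlations of the same magnitude carry the same information, and the value grows without bound as |ρ| approaches 1.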


Now consider the case where X and Y are scalar, but not jointly Gaussian. In general it can be 
difficult to compute the mutual information between continuous random variables, because we have 
to estimate the joint density p(X, Y). For scalar variables, a simple approximation is to discretize 
or quantize them, by dividing the ranges of each variable into bins, and computing how many values 
fall in each histogram bin [Sco79]. We can then easily compute the MI using the empirical pmf. 

Unfortunately, the number of bins used, and the location of the bin boundaries, can have a 
significant effect on the results. One way to avoid this is to use K-nearest neighbor distances to 
estimate densities in a non-parametric, adaptive way. This is the basis of the KSG estimator for MI 
proposed in [KSG04]. This is implemented in the sklearn.feature_selection.mutual_info_regression
function. For papers related to this estimator, see [GOV18; HN19].
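
The sensitivity of the binning approach to the number of bins is easy to observe. The sketch below is an illustration under stated assumptions, not the KSG estimator: it draws correlated Gaussian samples (ρ = 0.8, 5000 points, both arbitrary choices), for which Equation (6.70) gives the true MI, and compares plug-in histogram estimates for a few bin counts. The function name `binned_mi` is hypothetical.

```python
import numpy as np

def binned_mi(x, y, bins):
    """Plug-in MI estimate (in nats) from a 2d histogram with equal-width bins."""
    counts, _, _ = np.histogram2d(x, y, bins=bins)
    pxy = counts / counts.sum()                 # empirical joint pmf
    px = pxy.sum(axis=1, keepdims=True)
    py = pxy.sum(axis=0, keepdims=True)
    mask = pxy > 0                              # empty cells contribute 0
    return float(np.sum(pxy[mask] * np.log(pxy[mask] / (px @ py)[mask])))

rng = np.random.default_rng(0)
rho, n = 0.8, 5000
x = rng.standard_normal(n)
y = rho * x + np.sqrt(1 - rho ** 2) * rng.standard_normal(n)

true_mi = -0.5 * np.log(1 - rho ** 2)           # Equation (6.70), in nats
estimates = {b: binned_mi(x, y, b) for b in (4, 16, 64)}

print(f"true MI: {true_mi:.3f} nats")
for b, est in estimates.items():
    print(f"{b:3d} bins per axis: {est:.3f} nats")
```

Coarse grids discard information and tend to underestimate the MI, while fine grids with few samples per cell tend to overestimate it, which is the motivation for adaptive estimators.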
6.3.6 Normalized mutual information


For some applications, it is useful to have a normalized measure of dependence, between 0 and 1. We 
now discuss one way to construct such a measure. 
First, note that 


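\begin{align}
\mathbb{I}(X;Y) &= \mathbb{H}(X) - \mathbb{H}(X|Y) \le \mathbb{H}(X) \tag{6.71} \\
&= \mathbb{H}(Y) - \mathbb{H}(Y|X) \le \mathbb{H}(Y) \tag{6.72} \\
0 \le \mathbb{I}(X;Y) &\le \min\left(\mathbb{H}(X), \mathbb{H}(Y)\right) \tag{6.73}
\end{align}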


Therefore we can define the normalized mutual information as follows: 


$$\text{NMI}(X,Y) = \frac{\mathbb{I}(X;Y)}{\min\left(\mathbb{H}(X), \mathbb{H}(Y)\right)} \le 1 \tag{6.74}$$


This normalized mutual information ranges from 0 to 1. When NMI(X,Y) = 0, we have 
I(X;Y) =0, so X and Y are independent. When NMI(X,Y) = 1, and H(X) < H(Y), we have 


$$\mathbb{I}(X;Y) = \mathbb{H}(X) - \mathbb{H}(X|Y) = \mathbb{H}(X) \implies \mathbb{H}(X|Y) = 0 \tag{6.75}$$


and so X is a deterministic function of Y. For example, suppose X is a discrete random variable 
with pmf [0.5,0.25, 0.25]. We have MI(X, X) = 1.5 (using log base 2), and H(X) = 1.5, so the 
normalized MI is 1, as is to be expected. 
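
The example above can be reproduced directly from Equation (6.74). In this sketch the joint pmf of (X, X) is represented as a diagonal matrix; the helper names `entropy_bits` and `nmi` are illustrative.

```python
import numpy as np

def entropy_bits(p):
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return float(-np.sum(p * np.log2(p)))

def nmi(pxy):
    """Normalized MI of a 2d joint pmf, following Equation (6.74)."""
    pxy = np.asarray(pxy, dtype=float)
    h_x = entropy_bits(pxy.sum(axis=1))
    h_y = entropy_bits(pxy.sum(axis=0))
    i_xy = h_x + h_y - entropy_bits(pxy)   # Equation (6.51)
    return i_xy / min(h_x, h_y)

p = np.array([0.5, 0.25, 0.25])
p_xx = np.diag(p)            # joint pmf of (X, X): all mass on the diagonal
p_indep = np.outer(p, p)     # joint pmf of two independent copies of X

print(entropy_bits(p))       # H(X) in bits
print(nmi(p_xx))             # X determines itself
print(nmi(p_indep))          # independent variables
```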

For continuous random variables, it is harder to normalize the mutual information, because of 
the need to estimate the differential entropy, which is sensitive to the level of quantization. See 
Section 6.3.7 for further discussion. 


6.3.7 Maximal information coefficient 


As we discussed in Section 6.3.6, it is useful to have a normalized estimate of the mutual information, 
but this can be tricky to compute for real-valued data. One approach, known as the maximal 
information coefficient (MIC) [Res+11], is to define the following quantity: 


$$\text{MIC}(X,Y) = \max_{G \in \mathcal{G}} \frac{\mathbb{I}\left((X,Y)|_G\right)}{\log \|G\|} \tag{6.76}$$


where $\mathcal{G}$ is the set of 2d grids, $(X,Y)|_G$ represents a discretization of the variables onto this
grid, and $\|G\|$ is $\min(G_x, G_y)$, where $G_x$ is the number of grid cells in the x direction, and $G_y$ is
the number of grid cells in the y direction. (The maximum grid resolution depends on the sample
size n; they suggest restricting grids so that $G_x G_y \le B(n)$, where $B(n) = n^{\alpha}$, where $\alpha = 0.6$.) The
denominator is the entropy of a uniform joint distribution; dividing by this ensures $0 \le \text{MIC} \le 1$.
The intuition behind this statistic is the following: if there is a relationship between X and Y, 
then there should be some discrete gridding of the 2d input space that captures this. Since we don’t 
know the correct grid to use, MIC searches over different grid resolutions (e.g., 2x2, 2x3, etc), as well 
as over locations of the grid boundaries. Given a grid, it is easy to quantize the data and compute 
Figure 6.5: Illustration of how the maximal information coefficient (MIC) is computed. (a) We search over
different grid resolutions, and grid cell locations, and compute the MI for each. (b) For each grid resolution
(k, l), we define set M(k, l) to be the maximum MI for any grid of that size, normalized by log(min(k, l)). (c)
We visualize the matrix M. The maximum entry (denoted by a star) is defined to be the MIC. From Figure 1
of [Res+11]. Used with kind permission of David Reshef.


MI. We define the characteristic matrix M(k, l) to be the maximum MI achievable by any grid
of size (k, l), normalized by log(min(k, l)). The MIC is then the maximum entry in this matrix,
$\max_{kl \le B(n)} M(k,l)$. See Figure 6.5 for a visualization of this process.

In [Res+11], they show that this quantity exhibits a property known as equitability, which means 
that it gives similar scores to equally noisy relationships, regardless of the type of relationship (e.g., 
linear, non-linear, non-functional). 

In [Res+16], they present an improved estimator, called MICe, which is more efficient to compute, 
and only requires optimizing over 1d grids, which can be done in O(n) time using dynamic programming.
They also present another quantity, called TICe (total information content), that has higher
power to detect relationships from small sample sizes, but lower equitability. This is defined to be
$\sum_{kl \le B(n)} M(k,l)$. They recommend using TICe to screen a large number of candidate relationships,
and then using MICe to quantify the strength of the relationship. For an efficient implementation of 
both of these metrics, see [Alb+18]. 

We can interpret MIC of 0 to mean there is no relationship between the variables, and 1 to represent 
a noise-free relationship of any form. This is illustrated in Figure 6.6. Unlike correlation coefficients, 
MIC is not restricted to finding linear relationships. For this reason, the MIC has been called “a 
correlation for the 21st century” [Spe11].

In Figure 6.7, we give a more interesting example, from [Res+11]. The data consists of 357 variables 
measuring a variety of social, economic, health and political indicators, collected by the World Health 
Organization (WHO). On the left of the figure, we see the correlation coefficient (CC) plotted against 
the MIC for all 63,546 variable pairs. On the right of the figure, we see scatter plots for particular 
pairs of variables, which we now discuss: 


• The point marked C (near 0,0 on the plot) has a low CC and a low MIC. The corresponding
scatter plot makes it clear that there is no relationship between these two variables (percentage of
lives lost to injury and density of dentists in the population).


• The points marked D and H have high CC (in absolute value) and high MIC, because they
represent nearly linear relationships.


• The points marked E, F, and G have low CC but high MIC. This is because they correspond
to non-linear (and sometimes, as in the case of E and F, non-functional, i.e., one-to-many)
relationships between the variables.


Figure 6.6: Plots of some 2d distributions and the corresponding estimate of correlation coefficient $R^2$ and the
maximal information coefficient (MIC). Compare to Figure 3.1. Generated by MIC_correlation_2d.ipynb.


6.3.8 Data processing inequality 


Suppose we have an unknown variable X, and we observe a noisy function of it, call it Y. If we 
process the noisy observations in some way to create a new variable Z, it should be intuitively obvious 
that we cannot increase the amount of information we have about the unknown quantity, X. This is 
known as the data processing inequality. We now state this more formally, and then prove it. 


Theorem 6.3.1. Suppose $X \to Y \to Z$ forms a Markov chain, so that $X \perp Z \mid Y$. Then $\mathbb{I}(X;Y) \ge \mathbb{I}(X;Z)$.


Proof. By the chain rule for mutual information (Equation (6.62)), we can expand the mutual 
information in two different ways: 


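\begin{align}
\mathbb{I}(X;Y,Z) &= \mathbb{I}(X;Z) + \mathbb{I}(X;Y|Z) \tag{6.77} \\
&= \mathbb{I}(X;Y) + \mathbb{I}(X;Z|Y) \tag{6.78}
\end{align}


Since $X \perp Z \mid Y$, we have $\mathbb{I}(X;Z|Y) = 0$, so

$$\mathbb{I}(X;Z) + \mathbb{I}(X;Y|Z) = \mathbb{I}(X;Y) \tag{6.79}$$


Since $\mathbb{I}(X;Y|Z) \ge 0$, we have $\mathbb{I}(X;Y) \ge \mathbb{I}(X;Z)$. Similarly one can prove that $\mathbb{I}(Y;Z) \ge \mathbb{I}(X;Z)$.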
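
The inequality can be illustrated numerically by building a random Markov chain and comparing the three pairwise mutual informations. The state-space sizes (3, 4 and 5) and the use of Dirichlet draws for the conditional tables are assumptions made for this sketch.

```python
import numpy as np

def mi(pab):
    """Mutual information (in bits) of a 2d joint pmf with strictly positive entries."""
    pa = pab.sum(axis=1, keepdims=True)
    pb = pab.sum(axis=0, keepdims=True)
    return float(np.sum(pab * np.log2(pab / (pa * pb))))

rng = np.random.default_rng(1)
px = rng.dirichlet(np.ones(3))                  # p(x)
py_given_x = rng.dirichlet(np.ones(4), size=3)  # p(y|x), shape (3, 4), rows sum to 1
pz_given_y = rng.dirichlet(np.ones(5), size=4)  # p(z|y), shape (4, 5), rows sum to 1

# Joint of the chain X -> Y -> Z: p(x, y, z) = p(x) p(y|x) p(z|y).
pxyz = px[:, None, None] * py_given_x[:, :, None] * pz_given_y[None, :, :]

I_XY = mi(pxyz.sum(axis=2))
I_XZ = mi(pxyz.sum(axis=1))
I_YZ = mi(pxyz.sum(axis=0))
print(f"I(X;Y) = {I_XY:.4f}, I(Y;Z) = {I_YZ:.4f}, I(X;Z) = {I_XZ:.4f}")
```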
Figure 6.7: Left: Correlation coefficient vs maximal information coefficient (MIC) for all pairwise relationships
in the WHO data. Right: scatter plots of certain pairs of variables. The red lines are non-parametric
smoothing regressions fit separately to each trend. From Figure 4 of [Res+11]. Used with kind permission of
David Reshef.


6.3.9 Sufficient Statistics 


An important consequence of the DPI is the following. Suppose we have the chain $\theta \to \mathcal{D} \to s(\mathcal{D})$.
Then


$$\mathbb{I}(\theta; s(\mathcal{D})) \le \mathbb{I}(\theta; \mathcal{D}) \tag{6.80}$$


If this holds with equality, then we say that $s(\mathcal{D})$ is a sufficient statistic of the data $\mathcal{D}$ for the
purposes of inferring $\theta$. In this case, we can equivalently write $\theta \to s(\mathcal{D}) \to \mathcal{D}$, since we can
reconstruct the data from knowing $s(\mathcal{D})$ just as accurately as from knowing $\theta$.

An example of a sufficient statistic is the data itself, s(D) = D, but this is not very useful, since 
it doesn’t summarize the data at all. Hence we define a minimal sufficient statistic s(D) as one 
which is sufficient, and which contains no extra information about $\theta$; thus $s(\mathcal{D})$ maximally compresses
the data $\mathcal{D}$ without losing information which is relevant to predicting $\theta$. More formally, we say s is a
minimal sufficient statistic for $\mathcal{D}$ if for all sufficient statistics $s'(\mathcal{D})$ there is some function f such that
$s(\mathcal{D}) = f(s'(\mathcal{D}))$. We can summarize the situation as follows:


$$\theta \to s(\mathcal{D}) \to s'(\mathcal{D}) \to \mathcal{D} \tag{6.81}$$


Here s’(D) takes s(D) and adds redundant information to it, thus creating a one-to-many mapping. 

For example, a minimal sufficient statistic for a set of N Bernoulli trials is simply N and $N_1 = \sum_n \mathbb{I}(X_n = 1)$,
i.e., the number of successes. In other words, we don’t need to keep track of the
entire sequence of heads and tails and their ordering, we only need to keep track of the total number 
of heads and tails. Similarly, for inferring the mean of a Gaussian distribution with known variance 
we only need to know the empirical mean and number of samples. 
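
The Bernoulli case can be made concrete: two sequences with the same length and the same number of successes give identical likelihoods for every parameter value. The sequences and function names below are made up for illustration.

```python
import numpy as np

def bernoulli_log_lik(theta, data):
    """Log-likelihood of a full binary sequence under Ber(theta)."""
    data = np.asarray(data)
    return float(np.sum(data * np.log(theta) + (1 - data) * np.log(1 - theta)))

def bernoulli_log_lik_from_counts(theta, n1, n):
    """Same log-likelihood, computed only from the counts (N_1, N)."""
    return float(n1 * np.log(theta) + (n - n1) * np.log(1 - theta))

# Two hypothetical sequences with different orderings but N = 8 and N_1 = 4.
seq_a = [1, 1, 0, 0, 1, 0, 1, 0]
seq_b = [0, 0, 0, 1, 1, 1, 1, 0]

for theta in (0.1, 0.5, 0.8):
    print(theta,
          bernoulli_log_lik(theta, seq_a),
          bernoulli_log_lik(theta, seq_b),
          bernoulli_log_lik_from_counts(theta, n1=4, n=8))
```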
6.3.10 Fano’s inequality *


A common method for feature selection is to pick input features $X_d$ which have high mutual
information with the response variable Y. Below we justify why this is a reasonable thing to do. 
In particular, we state a result, known as Fano’s inequality, which bounds the probability of 
misclassification (for any method) in terms of the mutual information between the features X and 
the class label Y. 


Theorem 6.3.2. (Fano’s inequality) Consider an estimator $\hat{Y} = f(X)$ such that $Y \to X \to \hat{Y}$ forms
a Markov chain. Let E be the event $\hat{Y} \ne Y$, indicating that an error occurred, and let $P_e = P(Y \ne \hat{Y})$
be the probability of error. Then we have


$$\mathbb{H}(Y|X) \le \mathbb{H}(Y|\hat{Y}) \le \mathbb{H}(E) + P_e \log|\mathcal{Y}| \tag{6.82}$$


Since $\mathbb{H}(E) \le 1$, as we saw in Figure 6.1, we can weaken this result to get


$$1 + P_e \log|\mathcal{Y}| \ge \mathbb{H}(Y|X) \tag{6.83}$$

and hence

$$P_e \ge \frac{\mathbb{H}(Y|X) - 1}{\log|\mathcal{Y}|} \tag{6.84}$$


Thus minimizing $\mathbb{H}(Y|X)$ (which can be done by maximizing $\mathbb{I}(X;Y)$) will also minimize the lower
bound on $P_e$.
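
To see the weakened bound of Equation (6.84) in action, the sketch below builds a toy problem and compares the bound with the error of the MAP classifier. Everything about the toy problem is an assumption made for illustration: 8 equally likely classes, and a feature X that equals the label with probability 0.4 and is otherwise uniform over the other 7 values. Logs are base 2, so that H(E) ≤ 1.

```python
import numpy as np

def entropy_bits(p):
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return float(-np.sum(p * np.log2(p)))

K = 8      # number of classes |Y| (assumed)
a = 0.4    # probability that the feature X equals the label Y (assumed)

# Joint pmf with rows indexing x and columns indexing y:
# p(y) = 1/K and p(x|y) = a if x == y, else (1 - a) / (K - 1).
p_x_given_y = np.full((K, K), (1 - a) / (K - 1))
np.fill_diagonal(p_x_given_y, a)
pxy = p_x_given_y / K

H_Y_given_X = entropy_bits(pxy) - entropy_bits(pxy.sum(axis=1))
fano_bound = (H_Y_given_X - 1.0) / np.log2(K)        # Equation (6.84)
bayes_error = 1.0 - float(np.sum(pxy.max(axis=1)))   # error of the MAP rule

print(f"H(Y|X) = {H_Y_given_X:.3f} bits")
print(f"Fano lower bound on P_e = {fano_bound:.3f}")
print(f"error of the MAP classifier = {bayes_error:.3f}")
```

Since the bound holds for any estimator, it must lie below the MAP error in particular.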


Proof. (From [CT06, p38].) Using the chain rule for entropy, we have 


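\begin{align}
\mathbb{H}(E, Y|\hat{Y}) &= \mathbb{H}(Y|\hat{Y}) + \underbrace{\mathbb{H}(E|Y,\hat{Y})}_{=0} \tag{6.85} \\
&= \mathbb{H}(E|\hat{Y}) + \mathbb{H}(Y|E,\hat{Y}) \tag{6.86}
\end{align}


Since conditioning reduces entropy (see Section 6.2.4), we have $\mathbb{H}(E|\hat{Y}) \le \mathbb{H}(E)$. The final term
can be bounded as follows:


\begin{align}
\mathbb{H}(Y|E,\hat{Y}) &= P(E=0)\,\mathbb{H}(Y|\hat{Y}, E=0) + P(E=1)\,\mathbb{H}(Y|\hat{Y}, E=1) \tag{6.87} \\
&\le (1 - P_e)\,0 + P_e \log|\mathcal{Y}| \tag{6.88}
\end{align}

Hence

$$\mathbb{H}(Y|\hat{Y}) \le \underbrace{\mathbb{H}(E|\hat{Y})}_{\le \mathbb{H}(E)} + \underbrace{\mathbb{H}(Y|E,\hat{Y})}_{\le P_e \log|\mathcal{Y}|} \tag{6.89}$$


Finally, by the data processing inequality, we have $\mathbb{I}(Y;\hat{Y}) \le \mathbb{I}(Y;X)$, so $\mathbb{H}(Y|X) \le \mathbb{H}(Y|\hat{Y})$, which
establishes Equation (6.82).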
6.4 Exercises


Exercise 6.1 [Expressing mutual information in terms of entropies *]
Prove the following identities:

$$\mathbb{I}(X;Y) = \mathbb{H}(X) - \mathbb{H}(X|Y) = \mathbb{H}(Y) - \mathbb{H}(Y|X) \tag{6.90}$$

and

$$\mathbb{H}(X,Y) = \mathbb{H}(X|Y) + \mathbb{H}(Y|X) + \mathbb{I}(X;Y) \tag{6.91}$$


Exercise 6.2 [Relationship between D(p||q) and χ² statistic]
(Source: [CT91, Q12.2].)
Show that, if p(x) ≈ q(x), then

$$D_{\mathbb{KL}}(p \parallel q) \approx \frac{1}{2}\chi^2 \tag{6.92}$$

where

$$\chi^2 = \sum_x \frac{(p(x) - q(x))^2}{q(x)} \tag{6.93}$$

Hint: write

$$p(x) = \Delta(x) + q(x) \tag{6.94}$$

$$\frac{p(x)}{q(x)} = 1 + \frac{\Delta(x)}{q(x)} \tag{6.95}$$

and use the Taylor series expansion for log(1 + x).

$$\log(1 + x) = x - \frac{x^2}{2} + \frac{x^3}{3} - \frac{x^4}{4} + \cdots \tag{6.96}$$

for −1 < x ≤ 1.


Exercise 6.3 [Fun with entropies *]
(Source: Mackay.)
Consider the joint distribution p(X, Y), where the rows represent y and the columns x:

    | 1     2     3     4
  1 | 1/8   1/16  1/32  1/32
  2 | 1/16  1/8   1/32  1/32
  3 | 1/16  1/16  1/16  1/16
  4 | 1/4   0     0     0


a. What is the joint entropy H(X, Y)?
b. What are the marginal entropies H(X) and H(Y)?
c. The entropy of X conditioned on a specific value of y is defined as

$$\mathbb{H}(X|Y=y) = -\sum_x p(x|y)\log p(x|y) \tag{6.97}$$

Compute H(X|y) for each value of y. Does the posterior entropy on X ever increase given an observation
of Y?
d. The conditional entropy is defined as


$$\mathbb{H}(X|Y) = \sum_y p(y)\,\mathbb{H}(X|Y=y) \tag{6.98}$$


Compute this. Does the posterior entropy on X increase or decrease when averaged over the possible 
values of Y? 


e. What is the mutual information between X and Y? 


Exercise 6.4 [Forwards vs reverse KL divergence] 


(Source: Exercise 33.7 of [Mac03].) Consider a factored approximation q(x, y) = q(x)q(y) to a joint distribution
p(x, y). Show that to minimize the forwards KL $D_{\mathbb{KL}}(p \parallel q)$ we should set q(x) = p(x) and q(y) = p(y), i.e.,
the optimal approximation is a product of marginals.
Now consider the following joint distribution, where the rows represent y and the columns x.

    | 1     2     3     4
  1 | 1/8   1/8   0     0
  2 | 1/8   1/8   0     0
  3 | 0     0     1/4   0
  4 | 0     0     0     1/4


Show that the reverse KL $D_{\mathbb{KL}}(q \parallel p)$ for this p has three distinct minima. Identify those minima and evaluate
$D_{\mathbb{KL}}(q \parallel p)$ at each of them. What is the value of $D_{\mathbb{KL}}(q \parallel p)$ if we set q(x, y) = p(x)p(y)?
This chapter is co-authored with Zico Kolter.


7.1 Introduction 


Linear algebra is the study of matrices and vectors. In this chapter, we summarize the key material 
that we will need throughout the book. Much more information can be found in other sources, such 
as [Str09; Ips09; Kle13; Mol04; TB97; Axl15; Tho17; Agg20]. 


7.1.1 Notation 


In this section, we define some notation. 


7.1.1.1 Vectors 


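A vector $\boldsymbol{x} \in \mathbb{R}^n$ is a list of n numbers, usually written as a column vector


$$\boldsymbol{x} = \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_n \end{bmatrix} \tag{7.1}$$


The vector of all ones is denoted $\boldsymbol{1}$. The vector of all zeros is denoted $\boldsymbol{0}$.
The unit vector $\boldsymbol{e}_i$ is a vector of all 0’s, except entry i, which has value 1:


$$\boldsymbol{e}_i = (0, \ldots, 0, 1, 0, \ldots, 0) \tag{7.2}$$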


This is also called a one-hot vector. 


7.1.1.2 Matrices 


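A matrix $\mathbf{A} \in \mathbb{R}^{m \times n}$ with m rows and n columns is a 2d array of numbers, arranged as follows:


$$\mathbf{A} = \begin{bmatrix} a_{11} & a_{12} & \cdots & a_{1n} \\ a_{21} & a_{22} & \cdots & a_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ a_{m1} & a_{m2} & \cdots & a_{mn} \end{bmatrix} \tag{7.3}$$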
If m = n, the matrix is said to be square.

We use the notation $A_{ij}$ or $A_{i,j}$ to denote the entry of $\mathbf{A}$ in the i’th row and j’th column. We use
the notation $\mathbf{A}_{i,:}$ to denote the i’th row and $\mathbf{A}_{:,j}$ to denote the j’th column. We treat all vectors as
column vectors by default (so $\mathbf{A}_{i,:}$ is viewed as a column vector with n entries). We use bold upper
case letters to denote matrices, bold lower case letters to denote vectors, and non-bold letters to
denote scalars.

We can view a matrix as a set of columns stacked along the horizontal axis: 


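$$\mathbf{A} = \begin{bmatrix} | & | & & | \\ \mathbf{A}_{:,1} & \mathbf{A}_{:,2} & \cdots & \mathbf{A}_{:,n} \\ | & | & & | \end{bmatrix} \tag{7.4}$$


For brevity, we will denote this by

$$\mathbf{A} = [\mathbf{A}_{:,1}, \mathbf{A}_{:,2}, \ldots, \mathbf{A}_{:,n}] \tag{7.5}$$

We can also view a matrix as a set of rows stacked along the vertical axis:

$$\mathbf{A} = \begin{bmatrix} \text{--} & \mathbf{A}_{1,:}^{\mathsf{T}} & \text{--} \\ \text{--} & \mathbf{A}_{2,:}^{\mathsf{T}} & \text{--} \\ & \vdots & \\ \text{--} & \mathbf{A}_{m,:}^{\mathsf{T}} & \text{--} \end{bmatrix} \tag{7.6}$$

For brevity, we will denote this by

$$\mathbf{A} = [\mathbf{A}_{1,:}; \mathbf{A}_{2,:}; \ldots; \mathbf{A}_{m,:}] \tag{7.7}$$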


(Note the use of a semicolon.) 
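The transpose of a matrix results from “flipping” the rows and columns. Given a matrix $\mathbf{A} \in \mathbb{R}^{m \times n}$,
its transpose, written $\mathbf{A}^{\mathsf{T}} \in \mathbb{R}^{n \times m}$, is defined as


$$(\mathbf{A}^{\mathsf{T}})_{ij} = A_{ji} \tag{7.8}$$


The following properties of transposes are easily verified:


\begin{align}
(\mathbf{A}^{\mathsf{T}})^{\mathsf{T}} &= \mathbf{A} \tag{7.9} \\
(\mathbf{A}\mathbf{B})^{\mathsf{T}} &= \mathbf{B}^{\mathsf{T}}\mathbf{A}^{\mathsf{T}} \tag{7.10} \\
(\mathbf{A} + \mathbf{B})^{\mathsf{T}} &= \mathbf{A}^{\mathsf{T}} + \mathbf{B}^{\mathsf{T}} \tag{7.11}
\end{align}


If a square matrix satisfies $\mathbf{A} = \mathbf{A}^{\mathsf{T}}$, it is called symmetric. We denote the set of all symmetric
matrices of size n as $\mathbb{S}^n$.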


7.1.1.3 Tensors 


A tensor (in machine learning terminology) is just a generalization of a 2d array to more than 2 
dimensions, as illustrated in Figure 7.1. For example, the entries of a 3d tensor are denoted by $A_{ijk}$.
[Figure 7.1 panels: Vector in $\mathbb{R}^{64}$; Matrix in $\mathbb{R}^{8 \times 8}$; Tensor in $\mathbb{R}^{4 \times 4 \times 4}$.]


Figure 7.1: Illustration of a 1d vector, 2d matrix, and 3d tensor. The colors are used to represent individual
entries of the vector; this list of numbers can also be stored in a 2d matrix, as shown. (In this example, the
matrix is laid out in column-major order, which is the opposite of that used by Python.) We can also reshape
the vector into a 3d tensor, as shown.




Figure 7.2: Illustration of (a) row-major vs (b) column-major order. From
https://commons.wikimedia.org/wiki/File:Row_and_column_major_order.svg. Used with kind permission of Wikipedia author
Cmglee.


The number of dimensions is known as the order or rank of the tensor.¹ In mathematics, tensors
can be viewed as a way to define multilinear maps, just as matrices can be used to define linear 
functions, although we will not need to use this interpretation. 

We can reshape a matrix into a vector by stacking its columns on top of each other, as shown in 
Figure 7.1. This is denoted by 


$$\text{vec}(\mathbf{A}) = [\mathbf{A}_{:,1}; \cdots; \mathbf{A}_{:,n}] \in \mathbb{R}^{mn \times 1} \tag{7.12}$$

Conversely, we can reshape a vector into a matrix. There are two choices for how to do this, known
as row-major order (used by languages such as Python and C++) and column-major order
(used by languages such as Julia, Matlab, R and Fortran). See Figure 7.2 for an illustration of the
difference.
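
In NumPy, the two layouts correspond to the `order` argument of `reshape` and `flatten`. The small example below uses a 6-element vector chosen for illustration.

```python
import numpy as np

v = np.arange(6)   # the vector [0, 1, 2, 3, 4, 5]

# Row-major ("C") order fills each row before moving to the next one.
row_major = v.reshape(2, 3, order="C")
# Column-major ("F", for Fortran) order fills each column first.
col_major = v.reshape(2, 3, order="F")

print(row_major)   # [[0 1 2], [3 4 5]]
print(col_major)   # [[0 2 4], [1 3 5]]

# vec(A) in Equation (7.12) stacks the columns, i.e., column-major flattening.
vec_A = col_major.flatten(order="F")
print(vec_A)       # [0 1 2 3 4 5]
```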


1. Note, however, that the rank of a 2d matrix is a different concept, as discussed in Section 7.1.4.3. 
Figure 7.3: (a) Top: A vector v (blue) is added to another vector w (red). Bottom: w is stretched by a factor
of 2, yielding the sum v + 2w. From https://en.wikipedia.org/wiki/Vector_space. Used with kind
permission of Wikipedia author IkamusumeFan. (b) A vector v in $\mathbb{R}^2$ (blue) expressed in terms of different
bases: using the standard basis of $\mathbb{R}^2$, $\boldsymbol{v} = x\boldsymbol{e}_1 + y\boldsymbol{e}_2$ (black), and using a different, non-orthogonal basis:
$\boldsymbol{v} = \boldsymbol{f}_1 + \boldsymbol{f}_2$ (red). From https://en.wikipedia.org/wiki/Vector_space. Used with kind permission of
Wikipedia author Jakob.scholbach.


7.1.2 Vector spaces 


In this section, we discuss some fundamental concepts in linear algebra. 


7.1.2.1 Vector addition and scaling 


We can view a vector $\boldsymbol{x} \in \mathbb{R}^n$ as defining a point in n-dimensional Euclidean space. A vector space
is a collection of such vectors, which can be added together, and scaled by scalars (1-dimensional
numbers), in order to create new points. These operations are defined to operate elementwise, in
the obvious way, namely $\boldsymbol{x} + \boldsymbol{y} = (x_1 + y_1, \ldots, x_n + y_n)$ and $c\boldsymbol{x} = (cx_1, \ldots, cx_n)$, where $c \in \mathbb{R}$. See
Figure 7.3a for an illustration.


7.1.2.2 Linear independence, spans and basis sets 


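A set of vectors $\{\boldsymbol{x}_1, \boldsymbol{x}_2, \ldots, \boldsymbol{x}_n\}$ is said to be (linearly) independent if no vector can be represented
as a linear combination of the remaining vectors. Conversely, a vector which can be represented as a
linear combination of the remaining vectors is said to be (linearly) dependent. For example, if


$$\boldsymbol{x}_n = \sum_{i=1}^{n-1} \alpha_i \boldsymbol{x}_i \tag{7.13}$$


for some $\{\alpha_1, \ldots, \alpha_{n-1}\}$ then $\boldsymbol{x}_n$ is dependent on $\{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_{n-1}\}$; otherwise, it is independent of
$\{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_{n-1}\}$.

The span of a set of vectors $\{\boldsymbol{x}_1, \boldsymbol{x}_2, \ldots, \boldsymbol{x}_n\}$ is the set of all vectors that can be expressed as a
linear combination of $\{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n\}$. That is,


$$\text{span}(\{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n\}) \triangleq \left\{ \boldsymbol{v} : \boldsymbol{v} = \sum_{i=1}^{n} \alpha_i \boldsymbol{x}_i,\ \alpha_i \in \mathbb{R} \right\} \tag{7.14}$$

It can be shown that if $\{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n\}$ is a set of n linearly independent vectors, where each $\boldsymbol{x}_i \in \mathbb{R}^n$,
then $\text{span}(\{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n\}) = \mathbb{R}^n$. In other words, any vector $\boldsymbol{v} \in \mathbb{R}^n$ can be written as a linear
combination of $\boldsymbol{x}_1$ through $\boldsymbol{x}_n$.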

A basis $\mathcal{B}$ is a set of linearly independent vectors that spans the whole space, meaning that
$\text{span}(\mathcal{B}) = \mathbb{R}^n$. There are often multiple bases to choose from, as illustrated in Figure 7.3b. The
standard basis uses the coordinate vectors $\boldsymbol{e}_1 = (1, 0, \ldots, 0)$, up to $\boldsymbol{e}_n = (0, 0, \ldots, 0, 1)$. This
lets us translate back and forth between viewing a vector in $\mathbb{R}^2$ as either an “arrow in the plane”,
rooted at the origin, or as an ordered list of numbers (corresponding to the coefficients for each basis
vector).


7.1.2.3 Linear maps and matrices 


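A linear map or linear transformation is any function $f : \mathcal{V} \to \mathcal{W}$ such that $f(\boldsymbol{v} + \boldsymbol{w}) =
f(\boldsymbol{v}) + f(\boldsymbol{w})$ and $f(a\boldsymbol{v}) = a f(\boldsymbol{v})$ for all $\boldsymbol{v}, \boldsymbol{w} \in \mathcal{V}$. Once the basis of $\mathcal{V}$ is chosen, a linear map
$f : \mathcal{V} \to \mathcal{W}$ is completely determined by specifying the images of the basis vectors, because any
element of $\mathcal{V}$ can be expressed uniquely as a linear combination of them.

Suppose $\mathcal{V} = \mathbb{R}^n$ and $\mathcal{W} = \mathbb{R}^m$. We can compute $f(\boldsymbol{v}_i) \in \mathbb{R}^m$ for each basis vector in $\mathcal{V}$, and
store these along the columns of an m × n matrix $\mathbf{A}$. We can then compute $\boldsymbol{y} = f(\boldsymbol{x}) \in \mathbb{R}^m$ for any
$\boldsymbol{x} \in \mathbb{R}^n$ as follows:


$$\boldsymbol{y} = \left( \sum_{j=1}^{n} a_{1j} x_j, \ldots, \sum_{j=1}^{n} a_{mj} x_j \right) \tag{7.15}$$


This corresponds to multiplying the vector $\boldsymbol{x}$ by the matrix $\mathbf{A}$:

$$\boldsymbol{y} = \mathbf{A}\boldsymbol{x} \tag{7.16}$$


See Section 7.2 for more details.
If the function is invertible, we can write


$$\boldsymbol{x} = \mathbf{A}^{-1}\boldsymbol{y} \tag{7.17}$$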


See Section 7.3 for details. 


7.1.2.4 Range and nullspace of a matrix 


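Suppose we view a matrix $\mathbf{A} \in \mathbb{R}^{m \times n}$ as a set of n vectors in $\mathbb{R}^m$. The range (sometimes also called
the column space) of this matrix is the span of the columns of $\mathbf{A}$. In other words,


$$\text{range}(\mathbf{A}) = \{\boldsymbol{v} \in \mathbb{R}^m : \boldsymbol{v} = \mathbf{A}\boldsymbol{x}, \boldsymbol{x} \in \mathbb{R}^n\}. \tag{7.18}$$


This can be thought of as the set of vectors that can be “reached” or “generated” by $\mathbf{A}$; it is a
subspace of $\mathbb{R}^m$ whose dimensionality is given by the rank of $\mathbf{A}$ (see Section 7.1.4.3). The nullspace
of a matrix $\mathbf{A} \in \mathbb{R}^{m \times n}$ is the set of all vectors that get mapped to the null vector when multiplied
by $\mathbf{A}$, i.e.,


$$\text{nullspace}(\mathbf{A}) = \{\boldsymbol{x} \in \mathbb{R}^n : \mathbf{A}\boldsymbol{x} = \boldsymbol{0}\}. \tag{7.19}$$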
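Figure 7.4: Visualization of the nullspace and range of an m × n matrix $\mathbf{A}$. Here $\boldsymbol{y}_1 = \mathbf{A}\boldsymbol{x}_1$ and $\boldsymbol{y}_2 = \mathbf{A}\boldsymbol{x}_4$,
so $\boldsymbol{y}_1$ and $\boldsymbol{y}_2$ are in the range of $\mathbf{A}$ (are reachable from some $\boldsymbol{x}$). Also $\mathbf{A}\boldsymbol{x}_2 = \boldsymbol{0}$ and $\mathbf{A}\boldsymbol{x}_3 = \boldsymbol{0}$, so $\boldsymbol{x}_2$ and $\boldsymbol{x}_3$
are in the nullspace of $\mathbf{A}$ (get mapped to 0). We see that the range is often a subset of the input domain of
the mapping.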


The span of the rows of A is the orthogonal complement of the nullspace of A.
See Figure 7.4 for an illustration of the range and nullspace of a matrix. We shall discuss how to 
compute the range and nullspace of a matrix numerically in Section 7.5.4 below. 
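
As a preview, one common numerical route goes through the SVD: left singular vectors with non-zero singular values span the range, and right singular vectors with zero singular values span the nullspace. The sketch below is an assumption-laden illustration (a hand-picked 3 × 3 matrix of rank 2 and a relative tolerance of 1e-10 for deciding which singular values count as zero), not necessarily the procedure of Section 7.5.4.

```python
import numpy as np

# Hypothetical example: the second row is twice the first, so the rank is 2.
A = np.array([[1.0, 2.0, 3.0],
              [2.0, 4.0, 6.0],
              [1.0, 0.0, 1.0]])

U, s, Vt = np.linalg.svd(A)
tol = 1e-10 * s.max()            # assumed threshold for "numerically zero"
r = int(np.sum(s > tol))         # numerical rank

range_basis = U[:, :r]           # orthonormal basis for range(A)
row_space_basis = Vt[:r].T       # orthonormal basis for the span of the rows
null_basis = Vt[r:].T            # orthonormal basis for nullspace(A)

print("rank:", r)
print("nullspace basis:", null_basis.ravel())
print("A @ null_basis:", (A @ null_basis).ravel())
```

The last two bases are orthogonal to each other, which is the numerical counterpart of the statement relating the rows of A to its nullspace.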


7.1.2.5 Linear projection 


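The projection of a vector $\boldsymbol{y} \in \mathbb{R}^m$ onto the span of $\{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n\}$ (here we assume $\boldsymbol{x}_i \in \mathbb{R}^m$) is the
vector $\boldsymbol{v} \in \text{span}(\{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n\})$, such that $\boldsymbol{v}$ is as close as possible to $\boldsymbol{y}$, as measured by the Euclidean
norm $\|\boldsymbol{v} - \boldsymbol{y}\|_2$. We denote the projection as $\text{Proj}(\boldsymbol{y}; \{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n\})$ and can define it formally as

$$\text{Proj}(\boldsymbol{y}; \{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n\}) = \text{argmin}_{\boldsymbol{v} \in \text{span}(\{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n\})} \|\boldsymbol{y} - \boldsymbol{v}\|_2. \tag{7.20}$$


Given a (full rank) matrix $\mathbf{A} \in \mathbb{R}^{m \times n}$ with $m \ge n$, we can define the projection of a vector $\boldsymbol{y} \in \mathbb{R}^m$
onto the range of $\mathbf{A}$ as follows:


$$\text{Proj}(\boldsymbol{y}; \mathbf{A}) = \text{argmin}_{\boldsymbol{v} \in \mathcal{R}(\mathbf{A})} \|\boldsymbol{v} - \boldsymbol{y}\|_2 = \mathbf{A}(\mathbf{A}^{\mathsf{T}}\mathbf{A})^{-1}\mathbf{A}^{\mathsf{T}}\boldsymbol{y}. \tag{7.21}$$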


These are the same as the normal equations from Section 11.2.2.2. 
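
A quick numerical check of the projection formula, using a random full-rank 5 × 2 matrix (sizes and seed are arbitrary). The explicit inverse is used only to mirror the formula; in practice a least-squares solver such as `np.linalg.lstsq` is usually preferred for numerical stability.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 2))   # full column rank with probability 1
y = rng.standard_normal(5)

# Projection matrix onto range(A), as in Equation (7.21).
P = A @ np.linalg.inv(A.T @ A) @ A.T
y_hat = P @ y

# The same projection obtained from a least-squares solve.
coef, *_ = np.linalg.lstsq(A, y, rcond=None)

print(np.allclose(y_hat, A @ coef))          # both routes agree
print(np.allclose(A.T @ (y - y_hat), 0.0))   # residual is orthogonal to range(A)
print(np.allclose(P @ P, P))                 # projecting twice changes nothing
```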


7.1.3 Norms of a vector and matrix 


In this section, we discuss ways of measuring the “size” of a vector and matrix. 


7.1.3.1 Vector norms 


A norm of a vector $\|\boldsymbol{x}\|$ is, informally, a measure of the “length” of the vector. More formally, a
norm is any function $f : \mathbb{R}^n \to \mathbb{R}$ that satisfies 4 properties:


• For all $\boldsymbol{x} \in \mathbb{R}^n$, $f(\boldsymbol{x}) \ge 0$ (non-negativity).


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 




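• $f(\boldsymbol{x}) = 0$ if and only if $\boldsymbol{x} = \boldsymbol{0}$ (definiteness).
• For all $\boldsymbol{x} \in \mathbb{R}^n$, $t \in \mathbb{R}$, $f(t\boldsymbol{x}) = |t| f(\boldsymbol{x})$ (absolute value homogeneity).
• For all $\boldsymbol{x}, \boldsymbol{y} \in \mathbb{R}^n$, $f(\boldsymbol{x} + \boldsymbol{y}) \le f(\boldsymbol{x}) + f(\boldsymbol{y})$ (triangle inequality).

Consider the following common examples:

p-norm $\|\boldsymbol{x}\|_p = \left(\sum_{i=1}^n |x_i|^p\right)^{1/p}$, for $p \ge 1$.

2-norm $\|\boldsymbol{x}\|_2 = \sqrt{\sum_{i=1}^n x_i^2}$, also called Euclidean norm. Note that $\|\boldsymbol{x}\|_2^2 = \boldsymbol{x}^{\mathsf{T}}\boldsymbol{x}$.

1-norm $\|\boldsymbol{x}\|_1 = \sum_{i=1}^n |x_i|$.

Max-norm $\|\boldsymbol{x}\|_\infty = \max_i |x_i|$.

0-norm $\|\boldsymbol{x}\|_0 = \sum_{i=1}^n \mathbb{I}(|x_i| > 0)$. This is a pseudo norm, since it does not satisfy homogeneity.
It counts the number of non-zero elements in $\boldsymbol{x}$. If we define $0^0 = 0$, we can write this as
$\|\boldsymbol{x}\|_0 = \sum_{i=1}^n x_i^0$.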
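
These norms are available through `np.linalg.norm`. The vector below is an arbitrary example; the last lines illustrate why the 0-norm fails homogeneity.

```python
import numpy as np

x = np.array([3.0, -4.0, 0.0])

norm_1 = np.linalg.norm(x, 1)          # |3| + |-4| + |0|
norm_2 = np.linalg.norm(x, 2)          # sqrt(9 + 16)
norm_inf = np.linalg.norm(x, np.inf)   # largest absolute entry
norm_0 = np.count_nonzero(x)           # number of non-zero entries

print(norm_1, norm_2, norm_inf, norm_0)   # 7.0 5.0 4.0 2

# Scaling x by t = 2 doubles the 1-, 2- and max-norms, but leaves the
# 0-norm unchanged, so ||t x||_0 != |t| ||x||_0 in general.
print(np.linalg.norm(2 * x, 2), np.count_nonzero(2 * x))   # 10.0 2
```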


7.1.3.2 Matrix norms 


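Suppose we think of a matrix $\mathbf{A} \in \mathbb{R}^{m \times n}$ as defining a linear function $f(\boldsymbol{x}) = \mathbf{A}\boldsymbol{x}$. We define the
induced norm of $\mathbf{A}$ as the maximum amount by which f can lengthen any unit-norm input:


$$\|\mathbf{A}\|_p = \max_{\boldsymbol{x} \ne \boldsymbol{0}} \frac{\|\mathbf{A}\boldsymbol{x}\|_p}{\|\boldsymbol{x}\|_p} = \max_{\|\boldsymbol{x}\| = 1} \|\mathbf{A}\boldsymbol{x}\|_p \tag{7.22}$$


Typically p = 2, in which case


$$\|\mathbf{A}\|_2 = \sqrt{\lambda_{\max}(\mathbf{A}^{\mathsf{T}}\mathbf{A})} = \max_i \sigma_i \tag{7.23}$$


where $\lambda_{\max}(\mathbf{M})$ is the largest eigenvalue of $\mathbf{M}$, and $\sigma_i$ is the i’th singular value.
The nuclear norm, also called the trace norm, is defined as


$$\|\mathbf{A}\|_* = \text{tr}\left(\sqrt{\mathbf{A}^{\mathsf{T}}\mathbf{A}}\right) = \sum_i \sigma_i \tag{7.24}$$

where $\sqrt{\mathbf{A}^{\mathsf{T}}\mathbf{A}}$ is the matrix square root. Since the singular values are always non-negative, we have

$$\|\mathbf{A}\|_* = \sum_i |\sigma_i| = \|\boldsymbol{\sigma}\|_1 \tag{7.25}$$


Using this as a regularizer encourages many singular values to become zero, resulting in a low rank
matrix. More generally, we can define the Schatten p-norm as


$$\|\mathbf{A}\|_p = \left( \sum_i \sigma_i^p(\mathbf{A}) \right)^{1/p} \tag{7.26}$$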
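If we think of a matrix as a vector, we can define the matrix norm in terms of a vector norm,
$\|\mathbf{A}\| = \|\text{vec}(\mathbf{A})\|$. If the vector norm is the 2-norm, the corresponding matrix norm is the Frobenius
norm:


$$\|\mathbf{A}\|_F = \sqrt{\sum_{i=1}^{m}\sum_{j=1}^{n} a_{ij}^2} = \sqrt{\text{tr}(\mathbf{A}^{\mathsf{T}}\mathbf{A})} = \|\text{vec}(\mathbf{A})\|_2 \tag{7.27}$$


If $\mathbf{A}$ is expensive to evaluate, but $\mathbf{A}\boldsymbol{v}$ is cheap (for a random vector $\boldsymbol{v}$), we can create a stochastic
approximation to the Frobenius norm by using the Hutchinson trace estimator from Equation (7.37)
as follows:


$$\|\mathbf{A}\|_F^2 = \text{tr}(\mathbf{A}^{\mathsf{T}}\mathbf{A}) = \mathbb{E}\left[\boldsymbol{v}^{\mathsf{T}}\mathbf{A}^{\mathsf{T}}\mathbf{A}\boldsymbol{v}\right] = \mathbb{E}\left[\|\mathbf{A}\boldsymbol{v}\|_2^2\right] \tag{7.28}$$

where $\boldsymbol{v} \sim \mathcal{N}(\boldsymbol{0}, \mathbf{I})$.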
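
The spectral, nuclear and Frobenius norms can all be read off the singular values, which gives a simple cross-check against NumPy's built-in norms. The 2 × 2 matrix is an arbitrary example.

```python
import numpy as np

A = np.array([[3.0, 0.0],
              [4.0, 5.0]])

s = np.linalg.svd(A, compute_uv=False)   # singular values, in decreasing order

spectral = s.max()                       # Equation (7.23)
nuclear = s.sum()                        # Equation (7.24)
frobenius = np.sqrt(np.sum(s ** 2))      # Schatten 2-norm, equal to Equation (7.27)

# Cross-checks against NumPy's matrix norms and the eigenvalue form of (7.23).
print(np.isclose(spectral, np.linalg.norm(A, 2)))
print(np.isclose(nuclear, np.linalg.norm(A, "nuc")))
print(np.isclose(frobenius, np.linalg.norm(A, "fro")))
print(np.isclose(spectral, np.sqrt(np.linalg.eigvalsh(A.T @ A).max())))
```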


7.1.4 Properties of a matrix 


In this section, we discuss various scalar properties of matrices. 


7.1.4.1 Trace of a square matrix 


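The trace of a square matrix $\mathbf{A} \in \mathbb{R}^{n \times n}$, denoted $\text{tr}(\mathbf{A})$, is the sum of diagonal elements in the
matrix:


$$\text{tr}(\mathbf{A}) \triangleq \sum_{i=1}^{n} A_{ii} \tag{7.29}$$


The trace has the following properties, where $c \in \mathbb{R}$ is a scalar, and $\mathbf{A}, \mathbf{B} \in \mathbb{R}^{n \times n}$ are square
matrices:


\begin{align}
\text{tr}(\mathbf{A}) &= \text{tr}(\mathbf{A}^{\mathsf{T}}) \tag{7.30} \\
\text{tr}(\mathbf{A} + \mathbf{B}) &= \text{tr}(\mathbf{A}) + \text{tr}(\mathbf{B}) \tag{7.31} \\
\text{tr}(c\mathbf{A}) &= c\,\text{tr}(\mathbf{A}) \tag{7.32} \\
\text{tr}(\mathbf{A}\mathbf{B}) &= \text{tr}(\mathbf{B}\mathbf{A}) \tag{7.33} \\
\text{tr}(\mathbf{A}) &= \sum_{i=1}^{n} \lambda_i \text{ where } \lambda_i \text{ are the eigenvalues of } \mathbf{A} \tag{7.34}
\end{align}


We also have the following important cyclic permutation property: For $\mathbf{A}, \mathbf{B}, \mathbf{C}$ such that $\mathbf{A}\mathbf{B}\mathbf{C}$
is square,


$$\text{tr}(\mathbf{A}\mathbf{B}\mathbf{C}) = \text{tr}(\mathbf{B}\mathbf{C}\mathbf{A}) = \text{tr}(\mathbf{C}\mathbf{A}\mathbf{B}) \tag{7.35}$$

From this, we can derive the trace trick, which rewrites the scalar inner product $\boldsymbol{x}^{\mathsf{T}}\mathbf{A}\boldsymbol{x}$ as follows

$$\boldsymbol{x}^{\mathsf{T}}\mathbf{A}\boldsymbol{x} = \text{tr}(\boldsymbol{x}^{\mathsf{T}}\mathbf{A}\boldsymbol{x}) = \text{tr}(\boldsymbol{x}\boldsymbol{x}^{\mathsf{T}}\mathbf{A}) \tag{7.36}$$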
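In some cases, it may be expensive to evaluate the matrix $\mathbf{A}$, but we may be able to cheaply
evaluate matrix-vector products $\mathbf{A}\boldsymbol{v}$. Suppose $\boldsymbol{v}$ is a random vector such that $\mathbb{E}\left[\boldsymbol{v}\boldsymbol{v}^{\mathsf{T}}\right] = \mathbf{I}$. In this
case, we can create a Monte Carlo approximation to $\text{tr}(\mathbf{A})$ using the following identity:


$$\text{tr}(\mathbf{A}) = \text{tr}\left(\mathbf{A}\,\mathbb{E}\left[\boldsymbol{v}\boldsymbol{v}^{\mathsf{T}}\right]\right) = \mathbb{E}\left[\text{tr}(\mathbf{A}\boldsymbol{v}\boldsymbol{v}^{\mathsf{T}})\right] = \mathbb{E}\left[\text{tr}(\boldsymbol{v}^{\mathsf{T}}\mathbf{A}\boldsymbol{v})\right] = \mathbb{E}\left[\boldsymbol{v}^{\mathsf{T}}\mathbf{A}\boldsymbol{v}\right] \tag{7.37}$$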


This is called the Hutchinson trace estimator [Hut90]. 
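
A minimal Monte Carlo sketch of this estimator follows. It assumes standard Gaussian probe vectors (one valid choice satisfying E[vvᵀ] = I), a random 50 × 50 symmetric test matrix, and 2000 probes; all of these are illustrative choices. In a real use case only the products Av would be available, not A itself.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50
B = rng.standard_normal((n, n))
A = B @ B.T / n                  # symmetric positive semi-definite test matrix

num_probes = 2000
V = rng.standard_normal((n, num_probes))   # columns are probes v with E[v v^T] = I

# v^T A v for every probe, using only matrix-vector products A v.
quad_forms = np.sum(V * (A @ V), axis=0)
estimate = float(np.mean(quad_forms))
exact = float(np.trace(A))

print(f"Hutchinson estimate: {estimate:.3f}")
print(f"exact trace:         {exact:.3f}")
```

The estimate is unbiased, and its error shrinks at the usual Monte Carlo rate as the number of probes grows.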


7.1.4.2 Determinant of a square matrix 


The determinant of a square matrix, denoted det(A) or |A|, is a measure of how much it changes 
a unit volume when viewed as a linear transformation. (The formal definition is rather complex and 
is not needed here.) 

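The determinant operator satisfies these properties, where $\mathbf{A}, \mathbf{B} \in \mathbb{R}^{n \times n}$


\begin{align}
|\mathbf{A}| &= |\mathbf{A}^{\mathsf{T}}| \tag{7.38} \\
|c\mathbf{A}| &= c^n |\mathbf{A}| \tag{7.39} \\
|\mathbf{A}\mathbf{B}| &= |\mathbf{A}||\mathbf{B}| \tag{7.40} \\
|\mathbf{A}| &= 0 \text{ iff } \mathbf{A} \text{ is singular} \tag{7.41} \\
|\mathbf{A}^{-1}| &= 1/|\mathbf{A}| \text{ if } \mathbf{A} \text{ is not singular} \tag{7.42} \\
|\mathbf{A}| &= \prod_{i=1}^{n} \lambda_i \text{ where } \lambda_i \text{ are the eigenvalues of } \mathbf{A} \tag{7.43}
\end{align}


For a positive definite matrix $\mathbf{A}$, we can write $\mathbf{A} = \mathbf{L}\mathbf{L}^{\mathsf{T}}$, where $\mathbf{L}$ is the lower triangular Cholesky
decomposition. In this case, we have


$$\det(\mathbf{A}) = \det(\mathbf{L})\det(\mathbf{L}^{\mathsf{T}}) = \det(\mathbf{L})^2 \tag{7.44}$$


$$\log\det(\mathbf{A}) = 2\log\det(\mathbf{L}) = 2\log\prod_i L_{ii} = 2\,\text{tr}(\log(\text{diag}(\mathbf{L}))) \tag{7.45}$$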
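
Equation (7.45) translates directly into code. The sketch below builds an arbitrary positive definite matrix and compares the Cholesky-based log determinant with NumPy's `slogdet`.

```python
import numpy as np

rng = np.random.default_rng(0)
B = rng.standard_normal((4, 4))
A = B @ B.T + 4.0 * np.eye(4)    # positive definite by construction

L = np.linalg.cholesky(A)        # lower triangular factor with A = L L^T
logdet_chol = 2.0 * np.sum(np.log(np.diag(L)))   # Equation (7.45)

sign, logdet_ref = np.linalg.slogdet(A)          # reference value

print(logdet_chol, logdet_ref)
```

Summing logs of the diagonal avoids forming det(A) itself, which can overflow or underflow for large matrices.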


7.1.4.3 Rank of a matrix 


The column rank of a matrix A is the dimension of the space spanned by its columns, and the row 
rank is the dimension of the space spanned by its rows. It is a basic fact of linear algebra (that can be 
shown using the SVD, discussed in Section 7.5) that for any matrix A, columnrank(A) = rowrank(A), 
and so this quantity is simply referred to as the rank of A, denoted as rank(A). The following are 
some basic properties of the rank: 


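• For $\mathbf{A} \in \mathbb{R}^{m \times n}$, $\mathrm{rank}(\mathbf{A}) \le \min(m, n)$. If $\mathrm{rank}(\mathbf{A}) = \min(m, n)$, then $\mathbf{A}$ is said to be full rank, otherwise it is called rank deficient.


• For $\mathbf{A} \in \mathbb{R}^{m \times n}$, $\mathrm{rank}(\mathbf{A}) = \mathrm{rank}(\mathbf{A}^\top) = \mathrm{rank}(\mathbf{A}^\top \mathbf{A}) = \mathrm{rank}(\mathbf{A}\mathbf{A}^\top)$.


• For $\mathbf{A} \in \mathbb{R}^{m \times n}$, $\mathbf{B} \in \mathbb{R}^{n \times p}$, $\mathrm{rank}(\mathbf{A}\mathbf{B}) \le \min(\mathrm{rank}(\mathbf{A}), \mathrm{rank}(\mathbf{B}))$.


• For $\mathbf{A}, \mathbf{B} \in \mathbb{R}^{m \times n}$, $\mathrm{rank}(\mathbf{A} + \mathbf{B}) \le \mathrm{rank}(\mathbf{A}) + \mathrm{rank}(\mathbf{B})$.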


One can show that a square matrix is invertible iff it is full rank. 


7.1.4.4 Condition numbers 


The condition number of a matrix A is a measure of how numerically stable any computations 
involving A will be. It is defined as follows: 


\[ \kappa(\mathbf{A}) \triangleq \|\mathbf{A}\| \cdot \|\mathbf{A}^{-1}\| \tag{7.46} \]


where $\|\mathbf{A}\|$ is the norm of the matrix. We can show that $\kappa(\mathbf{A}) \ge 1$. (The condition number depends on which norm we use; we will assume the $\ell_2$-norm unless stated otherwise.)

We say $\mathbf{A}$ is well-conditioned if $\kappa(\mathbf{A})$ is small (close to 1), and ill-conditioned if $\kappa(\mathbf{A})$ is large. A large condition number means $\mathbf{A}$ is nearly singular. This is a better measure of nearness to singularity than the size of the determinant. For example, suppose $\mathbf{A} = 0.1\, \mathbf{I}_{100 \times 100}$. Then $\det(\mathbf{A}) = 10^{-100}$, which suggests $\mathbf{A}$ is nearly singular, but $\kappa(\mathbf{A}) = 1$, which means $\mathbf{A}$ is well-conditioned, reflecting the fact that $\mathbf{A}\mathbf{x}$ simply scales the entries of $\mathbf{x}$ by 0.1.

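To get a better understanding of condition numbers, consider the linear system of equations $\mathbf{A}\mathbf{x} = \mathbf{b}$. If $\mathbf{A}$ is non-singular, the unique solution is $\mathbf{x} = \mathbf{A}^{-1}\mathbf{b}$. Suppose we change $\mathbf{b}$ to $\mathbf{b} + \Delta\mathbf{b}$; what effect will that have on $\mathbf{x}$? The new solution must satisfy


\[ \mathbf{A}(\mathbf{x} + \Delta\mathbf{x}) = \mathbf{b} + \Delta\mathbf{b} \tag{7.47} \]
where
\[ \Delta\mathbf{x} = \mathbf{A}^{-1} \Delta\mathbf{b} \tag{7.48} \]


We say that $\mathbf{A}$ is well-conditioned if a small $\Delta\mathbf{b}$ results in a small $\Delta\mathbf{x}$; otherwise we say that $\mathbf{A}$ is ill-conditioned.
For example, suppose


\[ \mathbf{A} = 0.5 \begin{pmatrix} 1 & 1 \\ 1 + 10^{-10} & 1 - 10^{-10} \end{pmatrix}, \quad \mathbf{A}^{-1} = \begin{pmatrix} 1 - 10^{10} & 10^{10} \\ 1 + 10^{10} & -10^{10} \end{pmatrix} \tag{7.49} \]
The solution for $\mathbf{b} = (1, 1)$ is $\mathbf{x} = (1, 1)$. If we change $\mathbf{b}$ by $\Delta\mathbf{b}$, the solution changes by
\[ \Delta\mathbf{x} = \mathbf{A}^{-1} \Delta\mathbf{b} = \begin{pmatrix} \Delta b_1 - 10^{10}(\Delta b_1 - \Delta b_2) \\ \Delta b_1 + 10^{10}(\Delta b_1 - \Delta b_2) \end{pmatrix} \tag{7.50} \]


So a small change in $\mathbf{b}$ can lead to an extremely large change in $\mathbf{x}$, because $\mathbf{A}$ is ill-conditioned ($\kappa(\mathbf{A}) = 2 \times 10^{10}$).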
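Both examples in this section can be reproduced numerically. The sketch below is illustrative only: it uses `np.linalg.cond` (which defaults to the 2-norm), works with the log determinant of the scaled identity to avoid relying on the tiny determinant itself, and picks an arbitrary perturbation $\Delta\mathbf{b} = (10^{-6}, 0)$ for the system in Equation (7.49).

```python
import numpy as np

# Scaled identity: tiny determinant, yet perfectly conditioned.
I_scaled = 0.1 * np.eye(100)
_, logabsdet = np.linalg.slogdet(I_scaled)
log10_det = logabsdet / np.log(10.0)        # about -100
kappa_identity = np.linalg.cond(I_scaled)   # 2-norm condition number

# The 2x2 matrix of Equation (7.49).
eps = 1e-10
A = 0.5 * np.array([[1.0, 1.0],
                    [1.0 + eps, 1.0 - eps]])
kappa_A = np.linalg.cond(A)                 # roughly 2e10

b = np.array([1.0, 1.0])
db = np.array([1e-6, 0.0])                  # assumed small perturbation of b
x = np.linalg.solve(A, b)
dx = np.linalg.solve(A, b + db) - x         # change in the solution
```

Here a perturbation of size $10^{-6}$ in $\mathbf{b}$ moves the solution by roughly $10^{4}$ in each coordinate, in line with Equation (7.50).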

In the case of the $\ell_2$-norm, the condition number is equal to the ratio of the largest to smallest singular values (defined in Section 7.5); furthermore, the singular values of $\mathbf{A}$ are the square roots of the eigenvalues of $\mathbf{A}^\top \mathbf{A}$, and so


\[ \kappa(\mathbf{A}) = \sigma_{\max} / \sigma_{\min} = \sqrt{\frac{\lambda_{\max}}{\lambda_{\min}}} \tag{7.51} \]


We can gain further insight into condition numbers by considering a quadratic objective function $f(\mathbf{x}) = \mathbf{x}^\top \mathbf{A} \mathbf{x}$. If we plot the level set of this function, it will be elliptical, as shown in Section 7.4.4. As we increase the condition number of $\mathbf{A}$, the ellipses become more and more elongated along certain directions, corresponding to a very narrow valley in function space. If $\kappa = 1$ (the minimum possible value), the level set will be circular.


7.1.5 Special types of matrices 


In this section, we will list some common kinds of matrices with various forms of structure. 


7.1.5.1 Diagonal matrix 


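A diagonal matrix is a matrix where all non-diagonal elements are 0. This is typically denoted $\mathbf{D} = \mathrm{diag}(d_1, d_2, \ldots, d_n)$, with


\[ \mathbf{D} = \begin{pmatrix} d_1 & & & \\ & d_2 & & \\ & & \ddots & \\ & & & d_n \end{pmatrix} \tag{7.52} \]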


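The identity matrix, denoted $\mathbf{I} \in \mathbb{R}^{n \times n}$, is a square matrix with ones on the diagonal and zeros everywhere else, $\mathbf{I} = \mathrm{diag}(1, 1, \ldots, 1)$. It has the property that for all $\mathbf{A} \in \mathbb{R}^{n \times n}$,


\[ \mathbf{A}\mathbf{I} = \mathbf{A} = \mathbf{I}\mathbf{A} \tag{7.53} \]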


where the size of I is determined by the dimensions of A so that matrix multiplication is possible. 
We can extract the diagonal vector from a matrix using d = diag(D). We can convert a vector 
into a diagonal matrix by writing D = diag(d). 
A block diagonal matrix is one which contains matrices on its main diagonal, and is 0 everywhere 
else, e.g., 


\[ \begin{pmatrix} \mathbf{A} & \mathbf{0} \\ \mathbf{0} & \mathbf{B} \end{pmatrix} \tag{7.54} \]


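A band-diagonal matrix only has non-zero entries along the diagonal, and on $k$ sides of the diagonal, where $k$ is the bandwidth. For example, a tridiagonal $6 \times 6$ matrix looks like this:


\[ \begin{pmatrix}
A_{11} & A_{12} & 0 & \cdots & \cdots & 0 \\
A_{21} & A_{22} & A_{23} & \ddots & \ddots & \vdots \\
0 & A_{32} & A_{33} & A_{34} & \ddots & \vdots \\
\vdots & \ddots & A_{43} & A_{44} & A_{45} & 0 \\
\vdots & \ddots & \ddots & A_{54} & A_{55} & A_{56} \\
0 & \cdots & \cdots & 0 & A_{65} & A_{66}
\end{pmatrix} \tag{7.55} \]


7.1.5.2 Triangular matrices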


An upper triangular matrix only has non-zero entries on and above the diagonal. A lower 
triangular matrix only has non-zero entries on and below the diagonal. 

Triangular matrices have the useful property that the diagonal entries of $\mathbf{A}$ are the eigenvalues of $\mathbf{A}$, and hence the determinant is the product of diagonal entries: $\det(\mathbf{A}) = \prod_i A_{ii}$.


7.1.5.3 Positive definite matrices 


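Given a square matrix $\mathbf{A} \in \mathbb{R}^{n \times n}$ and a vector $\mathbf{x} \in \mathbb{R}^n$, the scalar value $\mathbf{x}^\top \mathbf{A} \mathbf{x}$ is called a quadratic form. Written explicitly, we see that


\[ \mathbf{x}^\top \mathbf{A} \mathbf{x} = \sum_{i=1}^{n} \sum_{j=1}^{n} A_{ij} x_i x_j \tag{7.56} \]


Note that,
\[ \mathbf{x}^\top \mathbf{A} \mathbf{x} = (\mathbf{x}^\top \mathbf{A} \mathbf{x})^\top = \mathbf{x}^\top \mathbf{A}^\top \mathbf{x} = \mathbf{x}^\top \left( \tfrac{1}{2}\mathbf{A} + \tfrac{1}{2}\mathbf{A}^\top \right) \mathbf{x} \tag{7.57} \]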


For this reason, we often implicitly assume that the matrices appearing in a quadratic form are 
symmetric. 
We give the following definitions: 


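• A symmetric matrix $\mathbf{A} \in \mathbb{S}^n$ is positive definite iff for all non-zero vectors $\mathbf{x} \in \mathbb{R}^n$, $\mathbf{x}^\top \mathbf{A} \mathbf{x} > 0$. This is usually denoted $\mathbf{A} \succ 0$ (or just $\mathbf{A} > 0$). If it is possible that $\mathbf{x}^\top \mathbf{A} \mathbf{x} = 0$, we say the matrix is positive semidefinite or psd. We denote the set of all positive definite matrices by $\mathbb{S}^n_{++}$.


• A symmetric matrix $\mathbf{A} \in \mathbb{S}^n$ is negative definite, denoted $\mathbf{A} \prec 0$ (or just $\mathbf{A} < 0$) iff for all non-zero $\mathbf{x} \in \mathbb{R}^n$, $\mathbf{x}^\top \mathbf{A} \mathbf{x} < 0$. If it is possible that $\mathbf{x}^\top \mathbf{A} \mathbf{x} = 0$, we say the matrix is negative semidefinite.


• A symmetric matrix $\mathbf{A} \in \mathbb{S}^n$ is indefinite, if it is neither positive semidefinite nor negative semidefinite, i.e., if there exist $\mathbf{x}_1, \mathbf{x}_2 \in \mathbb{R}^n$ such that $\mathbf{x}_1^\top \mathbf{A} \mathbf{x}_1 > 0$ and $\mathbf{x}_2^\top \mathbf{A} \mathbf{x}_2 < 0$.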


It should be obvious that if $\mathbf{A}$ is positive definite, then $-\mathbf{A}$ is negative definite and vice versa. Likewise, if $\mathbf{A}$ is positive semidefinite then $-\mathbf{A}$ is negative semidefinite and vice versa. If $\mathbf{A}$ is indefinite, then so is $-\mathbf{A}$. It can also be shown that positive definite and negative definite matrices are always invertible.

In Section 7.4.3.1, we show that a symmetric matrix is positive definite iff its eigenvalues are positive. Note that if all elements of $\mathbf{A}$ are positive, it does not mean $\mathbf{A}$ is necessarily positive definite. For example, $\mathbf{A} = \begin{pmatrix} 4 & 3 \\ 3 & 2 \end{pmatrix}$ is not positive definite. Conversely, a positive definite matrix can have negative entries, e.g., $\mathbf{A} = \begin{pmatrix} 2 & -1 \\ -1 & 2 \end{pmatrix}$.

A sufficient condition for a (real, symmetric) matrix with positive diagonal entries to be positive definite is that it is diagonally dominant, i.e., if in every row of the matrix, the magnitude of the diagonal entry in that row is larger than the sum of the magnitudes of all the other (non-diagonal) entries in that row. More precisely,


\[ |a_{ii}| > \sum_{j \neq i} |a_{ij}| \ \text{ for all } i \tag{7.58} \]

In 2d, any real, symmetric $2 \times 2$ matrix $\begin{pmatrix} a & b \\ b & d \end{pmatrix}$ is positive definite iff $a > 0$, $d > 0$ and $ad > b^2$.


Finally, there is one type of positive definite matrix that comes up frequently, and so deserves some special mention. Given any matrix $\mathbf{A} \in \mathbb{R}^{m \times n}$ (not necessarily symmetric or even square), the Gram matrix $\mathbf{G} = \mathbf{A}^\top \mathbf{A}$ is always positive semidefinite. Further, if $m \ge n$ (and we assume for convenience that $\mathbf{A}$ is full rank), then $\mathbf{G} = \mathbf{A}^\top \mathbf{A}$ is positive definite.
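The claim about Gram matrices follows from $\mathbf{x}^\top \mathbf{A}^\top \mathbf{A} \mathbf{x} = \|\mathbf{A}\mathbf{x}\|_2^2 \ge 0$, and can be checked numerically. The following is a small sketch with randomly generated matrices (an assumption for illustration); a tall Gaussian matrix has full column rank with probability one, while a wide one cannot.

```python
import numpy as np

rng = np.random.default_rng(0)

# Tall matrix (m >= n): the Gram matrix should be positive definite.
A_tall = rng.standard_normal((6, 3))
G_tall = A_tall.T @ A_tall
eig_tall = np.linalg.eigvalsh(G_tall)        # eigenvalues of a symmetric matrix

# Wide matrix (m < n): the 3x3 Gram matrix has rank at most 2,
# so it is only positive semidefinite (smallest eigenvalue is 0 up to round-off).
A_wide = rng.standard_normal((2, 3))
G_wide = A_wide.T @ A_wide
eig_wide = np.linalg.eigvalsh(G_wide)

# The quadratic form equals a squared norm, hence it is never negative.
x = rng.standard_normal(3)
quad_form = x @ G_tall @ x
squared_norm = np.linalg.norm(A_tall @ x) ** 2
```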


7.1.5.4 Orthogonal matrices 


Two vectors $\mathbf{x}, \mathbf{y} \in \mathbb{R}^n$ are orthogonal if $\mathbf{x}^\top \mathbf{y} = 0$. A vector $\mathbf{x} \in \mathbb{R}^n$ is normalized if $\|\mathbf{x}\|_2 = 1$. A set of vectors that is pairwise orthogonal and normalized is called orthonormal. A square matrix $\mathbf{U} \in \mathbb{R}^{n \times n}$ is orthogonal if all its columns are orthonormal. (Note the different meaning of the term orthogonal when talking about vectors versus matrices.) If the entries of $\mathbf{U}$ are complex valued, we use the term unitary instead of orthogonal.

It follows immediately from the definition of orthogonality and normality that $\mathbf{U}$ is orthogonal iff


\[ \mathbf{U}^\top \mathbf{U} = \mathbf{I} = \mathbf{U}\mathbf{U}^\top \tag{7.59} \]


In other words, the inverse of an orthogonal matrix is its transpose. Note that if $\mathbf{U}$ is not square, i.e., $\mathbf{U} \in \mathbb{R}^{m \times n}$ with $n < m$, but its columns are still orthonormal, then $\mathbf{U}^\top \mathbf{U} = \mathbf{I}$, but $\mathbf{U}\mathbf{U}^\top \neq \mathbf{I}$. We generally only use the term orthogonal to describe the previous case, where $\mathbf{U}$ is square.

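An example of an orthogonal matrix is a rotation matrix (see Exercise 7.1). For example, a rotation in 3d by angle $\alpha$ about the $z$ axis is given by


\[ \mathbf{R}(\alpha) = \begin{pmatrix} \cos(\alpha) & -\sin(\alpha) & 0 \\ \sin(\alpha) & \cos(\alpha) & 0 \\ 0 & 0 & 1 \end{pmatrix} \tag{7.60} \]


If $\alpha = 45^\circ$, this becomes


\[ \mathbf{R}(45) = \begin{pmatrix} \frac{1}{\sqrt{2}} & -\frac{1}{\sqrt{2}} & 0 \\ \frac{1}{\sqrt{2}} & \frac{1}{\sqrt{2}} & 0 \\ 0 & 0 & 1 \end{pmatrix} \tag{7.61} \]


where $\frac{1}{\sqrt{2}} = 0.7071$. We see that $\mathbf{R}(-\alpha) = \mathbf{R}(\alpha)^{-1} = \mathbf{R}(\alpha)^\top$, so this is an orthogonal matrix.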


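One nice property of orthogonal matrices is that operating on a vector with an orthogonal matrix will not change its Euclidean norm, i.e.,
\[ \|\mathbf{U}\mathbf{x}\|_2 = \|\mathbf{x}\|_2 \tag{7.62} \]


for any nonzero $\mathbf{x} \in \mathbb{R}^n$, and orthogonal $\mathbf{U} \in \mathbb{R}^{n \times n}$.
Similarly, one can show that the angle between two vectors is preserved after they are transformed by an orthogonal matrix. The cosine of the angle between $\mathbf{x}$ and $\mathbf{y}$ is given by


\[ \cos(\alpha(\mathbf{x}, \mathbf{y})) = \frac{\mathbf{x}^\top \mathbf{y}}{\|\mathbf{x}\| \|\mathbf{y}\|} \tag{7.63} \]
so
\[ \cos(\alpha(\mathbf{U}\mathbf{x}, \mathbf{U}\mathbf{y})) = \frac{(\mathbf{U}\mathbf{x})^\top (\mathbf{U}\mathbf{y})}{\|\mathbf{U}\mathbf{x}\| \|\mathbf{U}\mathbf{y}\|} = \frac{\mathbf{x}^\top \mathbf{y}}{\|\mathbf{x}\| \|\mathbf{y}\|} = \cos(\alpha(\mathbf{x}, \mathbf{y})) \tag{7.64} \]


In summary, transformations by orthogonal matrices are generalizations of rotations (if $\det(\mathbf{U}) = 1$) and reflections (if $\det(\mathbf{U}) = -1$), since they preserve lengths and angles.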
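These properties are easy to verify for the rotation matrix of Equation (7.60). The snippet below is a small sketch; the helper names `rotation_z` and `cosine`, and the two test vectors, are introduced here purely for illustration.

```python
import numpy as np

def rotation_z(alpha):
    """Rotation in 3d by angle alpha (radians) about the z axis, Equation (7.60)."""
    c, s = np.cos(alpha), np.sin(alpha)
    return np.array([[c, -s, 0.0],
                     [s,  c, 0.0],
                     [0.0, 0.0, 1.0]])

def cosine(a, b):
    """Cosine of the angle between two vectors, Equation (7.63)."""
    return (a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))

R = rotation_z(np.pi / 4)                  # 45 degrees
x = np.array([1.0, 2.0, 3.0])
y = np.array([-1.0, 0.5, 2.0])

is_orthogonal = np.allclose(R.T @ R, np.eye(3))
norm_before, norm_after = np.linalg.norm(x), np.linalg.norm(R @ x)
cos_before, cos_after = cosine(x, y), cosine(R @ x, R @ y)
det_R = np.linalg.det(R)                   # +1 for a rotation
```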

Note that there is a technique called Gram-Schmidt orthogonalization, which converts any square matrix with linearly independent columns into an orthogonal matrix, but we will not cover it here.


7.2 Matrix multiplication 


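The product of two matrices $\mathbf{A} \in \mathbb{R}^{m \times n}$ and $\mathbf{B} \in \mathbb{R}^{n \times p}$ is the matrix


\[ \mathbf{C} = \mathbf{A}\mathbf{B} \in \mathbb{R}^{m \times p} \tag{7.65} \]
where
\[ C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj} \tag{7.66} \]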


Note that in order for the matrix product to exist, the number of columns in A must equal the 
number of rows in B. 

Matrix multiplication generally takes O(mnp) time, although faster methods exist. In addition, 
specialized hardware, such as GPUs and TPUs, can be leveraged to speed up matrix multiplication 
significantly, by performing operations across the rows (or columns) in parallel. 

It is useful to know a few basic properties of matrix multiplication: 


• Matrix multiplication is associative: $(\mathbf{A}\mathbf{B})\mathbf{C} = \mathbf{A}(\mathbf{B}\mathbf{C})$.
• Matrix multiplication is distributive: $\mathbf{A}(\mathbf{B} + \mathbf{C}) = \mathbf{A}\mathbf{B} + \mathbf{A}\mathbf{C}$.
• Matrix multiplication is, in general, not commutative; that is, it can be the case that $\mathbf{A}\mathbf{B} \neq \mathbf{B}\mathbf{A}$.

(In each of the above cases, we are assuming that the dimensions match.)
There are many important special cases of matrix multiplication, as we discuss below.


7.2.1 Vector-vector products


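Given two vectors $\mathbf{x}, \mathbf{y} \in \mathbb{R}^n$, the quantity $\mathbf{x}^\top \mathbf{y}$, called the inner product, dot product or scalar product of the vectors, is a real number given by


\[ \langle \mathbf{x}, \mathbf{y} \rangle \triangleq \mathbf{x}^\top \mathbf{y} = \sum_{i=1}^{n} x_i y_i \tag{7.67} \]


Note that it is always the case that $\mathbf{x}^\top \mathbf{y} = \mathbf{y}^\top \mathbf{x}$.
Given vectors $\mathbf{x} \in \mathbb{R}^m$, $\mathbf{y} \in \mathbb{R}^n$ (they no longer have to be the same size), $\mathbf{x}\mathbf{y}^\top$ is called the outer product of the vectors. It is a matrix whose entries are given by $(\mathbf{x}\mathbf{y}^\top)_{ij} = x_i y_j$, i.e.,


\[ \mathbf{x}\mathbf{y}^\top \in \mathbb{R}^{m \times n} = \begin{pmatrix}
x_1 y_1 & x_1 y_2 & \cdots & x_1 y_n \\
x_2 y_1 & x_2 y_2 & \cdots & x_2 y_n \\
\vdots & \vdots & \ddots & \vdots \\
x_m y_1 & x_m y_2 & \cdots & x_m y_n
\end{pmatrix} \tag{7.68} \]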


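7.2.2 Matrix-vector products


Given a matrix $\mathbf{A} \in \mathbb{R}^{m \times n}$ and a vector $\mathbf{x} \in \mathbb{R}^n$, their product is a vector $\mathbf{y} = \mathbf{A}\mathbf{x} \in \mathbb{R}^m$. There are a couple of ways of looking at matrix-vector multiplication, and we will look at them both.
If we write $\mathbf{A}$ by rows, then we can express $\mathbf{y} = \mathbf{A}\mathbf{x}$ as follows:


\[ \mathbf{y} = \mathbf{A}\mathbf{x} = \begin{pmatrix} \mathbf{a}_1^\top \\ \mathbf{a}_2^\top \\ \vdots \\ \mathbf{a}_m^\top \end{pmatrix} \mathbf{x} = \begin{pmatrix} \mathbf{a}_1^\top \mathbf{x} \\ \mathbf{a}_2^\top \mathbf{x} \\ \vdots \\ \mathbf{a}_m^\top \mathbf{x} \end{pmatrix} \tag{7.69} \]


In other words, the $i$th entry of $\mathbf{y}$ is equal to the inner product of the $i$th row of $\mathbf{A}$ and $\mathbf{x}$, $y_i = \mathbf{a}_i^\top \mathbf{x}$.


Alternatively, let's write $\mathbf{A}$ in column form. In this case we see that


\[ \mathbf{y} = \mathbf{A}\mathbf{x} = \begin{pmatrix} \mathbf{a}_1 & \mathbf{a}_2 & \cdots & \mathbf{a}_n \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \\ \vdots \\ x_n \end{pmatrix} = \mathbf{a}_1 x_1 + \mathbf{a}_2 x_2 + \cdots + \mathbf{a}_n x_n \tag{7.70} \]


In other words, $\mathbf{y}$ is a linear combination of the columns of $\mathbf{A}$, where the coefficients of the linear combination are given by the entries of $\mathbf{x}$. We can view the columns of $\mathbf{A}$ as a set of basis vectors defining a linear subspace. We can construct vectors in this subspace by taking linear combinations of the basis vectors. See Section 7.1.2 for details.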


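7.2.3 Matrix-matrix products


Below we look at four different (but, of course, equivalent) ways of viewing the matrix-matrix multiplication $\mathbf{C} = \mathbf{A}\mathbf{B}$.

First we can view matrix-matrix multiplication as a set of vector-vector products. The most obvious viewpoint, which follows immediately from the definition, is that the $i,j$ entry of $\mathbf{C}$ is equal to the inner product of the $i$th row of $\mathbf{A}$ and the $j$th column of $\mathbf{B}$. Symbolically, this looks like the following,


\[ \mathbf{C} = \mathbf{A}\mathbf{B} = \begin{pmatrix} \mathbf{a}_1^\top \\ \mathbf{a}_2^\top \\ \vdots \\ \mathbf{a}_m^\top \end{pmatrix} \begin{pmatrix} \mathbf{b}_1 & \mathbf{b}_2 & \cdots & \mathbf{b}_p \end{pmatrix} = \begin{pmatrix}
\mathbf{a}_1^\top \mathbf{b}_1 & \mathbf{a}_1^\top \mathbf{b}_2 & \cdots & \mathbf{a}_1^\top \mathbf{b}_p \\
\mathbf{a}_2^\top \mathbf{b}_1 & \mathbf{a}_2^\top \mathbf{b}_2 & \cdots & \mathbf{a}_2^\top \mathbf{b}_p \\
\vdots & \vdots & \ddots & \vdots \\
\mathbf{a}_m^\top \mathbf{b}_1 & \mathbf{a}_m^\top \mathbf{b}_2 & \cdots & \mathbf{a}_m^\top \mathbf{b}_p
\end{pmatrix} \tag{7.71} \]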

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




Figure 7.5: Illustration of matrix multiplication. From https://en.wikipedia.org/wiki/Matrix_multiplication. Used with kind permission of Wikipedia author Bilou.


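Remember that since $\mathbf{A} \in \mathbb{R}^{m \times n}$ and $\mathbf{B} \in \mathbb{R}^{n \times p}$, $\mathbf{a}_i \in \mathbb{R}^n$ and $\mathbf{b}_j \in \mathbb{R}^n$, so these inner products all make sense. This is the most “natural” representation when we represent $\mathbf{A}$ by rows and $\mathbf{B}$ by columns. See Figure 7.5 for an illustration.

Alternatively, we can represent $\mathbf{A}$ by columns, and $\mathbf{B}$ by rows, which leads to the interpretation of $\mathbf{A}\mathbf{B}$ as a sum of outer products. Symbolically,


\[ \mathbf{C} = \mathbf{A}\mathbf{B} = \begin{pmatrix} \mathbf{a}_1 & \mathbf{a}_2 & \cdots & \mathbf{a}_n \end{pmatrix} \begin{pmatrix} \mathbf{b}_1^\top \\ \mathbf{b}_2^\top \\ \vdots \\ \mathbf{b}_n^\top \end{pmatrix} = \sum_{i=1}^{n} \mathbf{a}_i \mathbf{b}_i^\top \tag{7.72} \]


Put another way, $\mathbf{A}\mathbf{B}$ is equal to the sum, over all $i$, of the outer product of the $i$th column of $\mathbf{A}$ and the $i$th row of $\mathbf{B}$. Since, in this case, $\mathbf{a}_i \in \mathbb{R}^m$ and $\mathbf{b}_i \in \mathbb{R}^p$, the dimension of the outer product $\mathbf{a}_i \mathbf{b}_i^\top$ is $m \times p$, which coincides with the dimension of $\mathbf{C}$.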

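We can also view matrix-matrix multiplication as a set of matrix-vector products. Specifically, if we represent $\mathbf{B}$ by columns, we can view the columns of $\mathbf{C}$ as matrix-vector products between $\mathbf{A}$ and the columns of $\mathbf{B}$. Symbolically,


\[ \mathbf{C} = \mathbf{A}\mathbf{B} = \mathbf{A} \begin{pmatrix} \mathbf{b}_1 & \mathbf{b}_2 & \cdots & \mathbf{b}_p \end{pmatrix} = \begin{pmatrix} \mathbf{A}\mathbf{b}_1 & \mathbf{A}\mathbf{b}_2 & \cdots & \mathbf{A}\mathbf{b}_p \end{pmatrix} \tag{7.73} \]


Here the $i$th column of $\mathbf{C}$ is given by the matrix-vector product with the vector on the right, $\mathbf{c}_i = \mathbf{A}\mathbf{b}_i$. These matrix-vector products can in turn be interpreted using both viewpoints given in the previous subsection.

Finally, we have the analogous viewpoint, where we represent $\mathbf{A}$ by rows, and view the rows of $\mathbf{C}$ as the matrix-vector product between the rows of $\mathbf{A}$ and the matrix $\mathbf{B}$. Symbolically,


\[ \mathbf{C} = \mathbf{A}\mathbf{B} = \begin{pmatrix} \mathbf{a}_1^\top \\ \mathbf{a}_2^\top \\ \vdots \\ \mathbf{a}_m^\top \end{pmatrix} \mathbf{B} = \begin{pmatrix} \mathbf{a}_1^\top \mathbf{B} \\ \mathbf{a}_2^\top \mathbf{B} \\ \vdots \\ \mathbf{a}_m^\top \mathbf{B} \end{pmatrix} \tag{7.74} \]

Here the $i$th row of $\mathbf{C}$ is given by the matrix-vector product with the vector on the left, $\mathbf{c}_i^\top = \mathbf{a}_i^\top \mathbf{B}$.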

It may seem like overkill to dissect matrix multiplication to such a large degree, especially when all 
these viewpoints follow immediately from the initial definition we gave (in about a line of math) at 
the beginning of this section. However, virtually all of linear algebra deals with matrix multiplications 
of some kind, and it is worthwhile to spend some time trying to develop an intuitive understanding 
of the viewpoints presented here. 
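The four viewpoints can be written out directly and compared with the built-in product. The code below is a sketch for building intuition only (explicit Python loops like these are far slower than `A @ B`); the matrix sizes and variable names are arbitrary choices for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, p = 3, 4, 2
A = rng.standard_normal((m, n))
B = rng.standard_normal((n, p))
C = A @ B                                   # reference product

# View 1: C[i, j] is the inner product of row i of A and column j of B.
C_inner = np.array([[A[i, :] @ B[:, j] for j in range(p)] for i in range(m)])

# View 2: C is the sum over k of outer(column k of A, row k of B).
C_outer = sum(np.outer(A[:, k], B[k, :]) for k in range(n))

# View 3: column j of C is A times column j of B.
C_cols = np.column_stack([A @ B[:, j] for j in range(p)])

# View 4: row i of C is (row i of A) times B.
C_rows = np.vstack([A[i, :] @ B for i in range(m)])
```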

Finally, a word on notation. We write $\mathbf{A}^2$ as shorthand for $\mathbf{A}\mathbf{A}$, which is the matrix product. To denote elementwise squaring of the elements of a matrix, we write $\mathbf{A}^{\odot 2} = [A_{ij}^2]$. (If $\mathbf{A}$ is diagonal, then $\mathbf{A}^2 = \mathbf{A}^{\odot 2}$.)

We can also define the inverse of $\mathbf{A}^2$ using the matrix square root: we say $\mathbf{A} = \sqrt{\mathbf{M}}$ if $\mathbf{A}^2 = \mathbf{M}$. To denote elementwise square root of the elements of a matrix, we write $[\sqrt{M_{ij}}]$.


7.2.4 Application: manipulating data matrices 


As an application of the above results, consider the case where X is the N x D design matrix, whose 
rows are the data cases. There are various common preprocessing operations that we apply to this 
matrix, which we summarize below. (Writing these operations in matrix form is useful because it is 
notationally compact, and it allows us to implement the methods quickly using fast matrix code.) 


7.2.4.1 Summing slices of the matrix 


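Suppose $\mathbf{X}$ is an $N \times D$ matrix. We can sum across the rows by premultiplying by a $1 \times N$ matrix of ones to create a $1 \times D$ matrix:


\[ \mathbf{1}_N^\top \mathbf{X} = \begin{pmatrix} \sum_n x_{n1} & \cdots & \sum_n x_{nD} \end{pmatrix} \tag{7.75} \]


Hence the mean of the data vectors is given by


\[ \bar{\mathbf{x}}^\top = \frac{1}{N} \mathbf{1}_N^\top \mathbf{X} \tag{7.76} \]


We can sum across the columns by postmultiplying by a $D \times 1$ matrix of ones to create an $N \times 1$ matrix:


\[ \mathbf{X} \mathbf{1}_D = \begin{pmatrix} \sum_d x_{1d} \\ \vdots \\ \sum_d x_{Nd} \end{pmatrix} \tag{7.77} \]


We can sum all entries in a matrix by pre and post multiplying by a vector of 1s:
\[ \mathbf{1}_N^\top \mathbf{X} \mathbf{1}_D = \sum_{ij} X_{ij} \tag{7.78} \]


Hence the overall mean is given by


\[ \bar{x} = \frac{1}{ND} \mathbf{1}_N^\top \mathbf{X} \mathbf{1}_D \tag{7.79} \]


7.2.4.2 Scaling rows and columns of a matrix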


We often want to scale rows or columns of a data matrix (e.g., to standardize them). We now show 
how to write this in matrix notation. 

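If we pre-multiply $\mathbf{X}$ by a diagonal matrix $\mathbf{S} = \mathrm{diag}(\mathbf{s})$, where $\mathbf{s}$ is an $N$-vector, then we just scale each row of $\mathbf{X}$ by the corresponding scale factor in $\mathbf{s}$:


\[ \mathrm{diag}(\mathbf{s}) \mathbf{X} = \begin{pmatrix} s_1 & \cdots & 0 \\ & \ddots & \\ 0 & \cdots & s_N \end{pmatrix} \begin{pmatrix} x_{1,1} & \cdots & x_{1,D} \\ & \ddots & \\ x_{N,1} & \cdots & x_{N,D} \end{pmatrix} = \begin{pmatrix} s_1 x_{1,1} & \cdots & s_1 x_{1,D} \\ & \ddots & \\ s_N x_{N,1} & \cdots & s_N x_{N,D} \end{pmatrix} \tag{7.80} \]


If we post-multiply $\mathbf{X}$ by a diagonal matrix $\mathbf{S} = \mathrm{diag}(\mathbf{s})$, where $\mathbf{s}$ is a $D$-vector, then we just scale each column of $\mathbf{X}$ by the corresponding element in $\mathbf{s}$.


\[ \mathbf{X} \mathrm{diag}(\mathbf{s}) = \begin{pmatrix} x_{1,1} & \cdots & x_{1,D} \\ & \ddots & \\ x_{N,1} & \cdots & x_{N,D} \end{pmatrix} \begin{pmatrix} s_1 & \cdots & 0 \\ & \ddots & \\ 0 & \cdots & s_D \end{pmatrix} = \begin{pmatrix} s_1 x_{1,1} & \cdots & s_D x_{1,D} \\ & \ddots & \\ s_1 x_{N,1} & \cdots & s_D x_{N,D} \end{pmatrix} \tag{7.81} \]

Thus we can rewrite the standardization operation from Section 10.2.8 in matrix form as follows:


\[ \mathrm{standardize}(\mathbf{X}) = (\mathbf{X} - \mathbf{1}_N \boldsymbol{\mu}^\top) \mathrm{diag}(\boldsymbol{\sigma})^{-1} \tag{7.82} \]


where $\boldsymbol{\mu} = \bar{\mathbf{x}}$ is the empirical mean, and $\boldsymbol{\sigma}$ is a vector of the empirical standard deviations.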
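The slice sums and the standardization formula translate almost verbatim into NumPy. The sketch below follows the matrix expressions literally for clarity (in practice one would call `X.mean(axis=0)` and rely on broadcasting); the synthetic data and the use of the population standard deviation (`ddof=0`) are assumptions of this example.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 6, 3
# Synthetic data with different scales and offsets per column.
X = rng.standard_normal((N, D)) * np.array([1.0, 5.0, 0.2]) + np.array([0.0, 10.0, -3.0])

ones_N = np.ones((N, 1))
ones_D = np.ones((D, 1))

col_sums = ones_N.T @ X              # 1 x D, Equation (7.75)
mean_row = col_sums / N              # 1 x D, Equation (7.76)
row_sums = X @ ones_D                # N x 1, Equation (7.77)
overall_mean = (ones_N.T @ X @ ones_D).item() / (N * D)   # Equation (7.79)

mu = mean_row.ravel()
sigma = X.std(axis=0)                # population standard deviation per column
X_std = (X - ones_N @ mu[None, :]) @ np.diag(1.0 / sigma)  # Equation (7.82)
```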


7.2.4.3 Sum of squares and scatter matrix 
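The sum of squares matrix is a $D \times D$ matrix defined by


\[ \mathbf{S}_0 \triangleq \mathbf{X}^\top \mathbf{X} = \sum_{n=1}^{N} \mathbf{x}_n \mathbf{x}_n^\top = \sum_{n=1}^{N} \begin{pmatrix} x_{n,1}^2 & \cdots & x_{n,1} x_{n,D} \\ & \ddots & \\ x_{n,D} x_{n,1} & \cdots & x_{n,D}^2 \end{pmatrix} \tag{7.83} \]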


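The scatter matrix is a $D \times D$ matrix defined by


\[ \mathbf{S}_{\bar{\mathbf{x}}} \triangleq \sum_{n=1}^{N} (\mathbf{x}_n - \bar{\mathbf{x}})(\mathbf{x}_n - \bar{\mathbf{x}})^\top = \left( \sum_{n} \mathbf{x}_n \mathbf{x}_n^\top \right) - N \bar{\mathbf{x}} \bar{\mathbf{x}}^\top \tag{7.84} \]


We see that this is the sum of squares matrix applied to the mean-centered data. More precisely, define $\tilde{\mathbf{X}}$ to be a version of $\mathbf{X}$ where we subtract the mean $\bar{\mathbf{x}} = \frac{1}{N} \mathbf{X}^\top \mathbf{1}_N$ off every row. Hence we can compute the centered data matrix using


\[ \tilde{\mathbf{X}} = \mathbf{X} - \mathbf{1}_N \bar{\mathbf{x}}^\top = \mathbf{X} - \frac{1}{N} \mathbf{1}_N \mathbf{1}_N^\top \mathbf{X} = \mathbf{C}_N \mathbf{X} \tag{7.85} \]

where
\[ \mathbf{C}_N \triangleq \mathbf{I}_N - \frac{1}{N} \mathbf{J}_N \tag{7.86} \]
is the centering matrix, and $\mathbf{J}_N = \mathbf{1}_N \mathbf{1}_N^\top$ is a matrix of all 1s. The scatter matrix can now be computed as follows:


\[ \mathbf{S}_{\bar{\mathbf{x}}} = \tilde{\mathbf{X}}^\top \tilde{\mathbf{X}} = \mathbf{X}^\top \mathbf{C}_N^\top \mathbf{C}_N \mathbf{X} = \mathbf{X}^\top \mathbf{C}_N \mathbf{X} \tag{7.87} \]


where we exploited the fact that $\mathbf{C}_N$ is symmetric and idempotent, i.e., $\mathbf{C}_N^k = \mathbf{C}_N$ for $k = 1, 2, \ldots$ (since once we subtract the mean, subtracting it again has no effect).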
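The centering matrix and the three equivalent expressions for the scatter matrix can be checked on a small random data matrix, as in the sketch below. Forming the $N \times N$ matrix $\mathbf{C}_N$ explicitly is done here only to mirror Equations (7.85) to (7.87); for large $N$ one would subtract the column means directly.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 5, 3
X = rng.standard_normal((N, D))

C = np.eye(N) - np.ones((N, N)) / N      # centering matrix, Equation (7.86)
X_centered = C @ X                        # Equation (7.85)

xbar = X.mean(axis=0)
S_direct = (X - xbar).T @ (X - xbar)              # definition in Equation (7.84)
S_formula = X.T @ X - N * np.outer(xbar, xbar)    # right-hand side of (7.84)
S_via_C = X.T @ C @ X                             # Equation (7.87)
```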


7.2.4.4 Gram matrix 


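The $N \times N$ matrix $\mathbf{X}\mathbf{X}^\top$ is a matrix of inner products called the Gram matrix:


\[ \mathbf{K} \triangleq \mathbf{X}\mathbf{X}^\top = \begin{pmatrix} \mathbf{x}_1^\top \mathbf{x}_1 & \cdots & \mathbf{x}_1^\top \mathbf{x}_N \\ & \ddots & \\ \mathbf{x}_N^\top \mathbf{x}_1 & \cdots & \mathbf{x}_N^\top \mathbf{x}_N \end{pmatrix} \tag{7.88} \]


Sometimes we want to compute the inner products of the mean-centered data vectors, $\tilde{\mathbf{K}} = \tilde{\mathbf{X}}\tilde{\mathbf{X}}^\top$. However, if we are working with a feature similarity matrix instead of raw features, we will only have access to $\mathbf{K}$, not $\mathbf{X}$. (We will see examples of this in Section 20.4.4 and Section 20.4.6.) Fortunately, we can compute $\tilde{\mathbf{K}}$ from $\mathbf{K}$ using the double centering trick:

\[ \tilde{\mathbf{K}} = \tilde{\mathbf{X}}\tilde{\mathbf{X}}^\top = \mathbf{C}_N \mathbf{K} \mathbf{C}_N = \mathbf{K} - \frac{1}{N} \mathbf{J}\mathbf{K} - \frac{1}{N} \mathbf{K}\mathbf{J} + \frac{1}{N^2} \mathbf{J}\mathbf{K}\mathbf{J} \tag{7.89} \]

where $\mathbf{J} = \mathbf{1}_N \mathbf{1}_N^\top$, so the last term adds the scalar $\frac{1}{N^2} \mathbf{1}^\top \mathbf{K} \mathbf{1}$ to every entry. This subtracts the row means and column means from $\mathbf{K}$, and adds back the global mean that gets subtracted twice, so that both row means and column means of $\tilde{\mathbf{K}}$ are equal to zero.

To see why Equation (7.89) is true, consider the scalar form:


\begin{align}
\tilde{K}_{ij} = \tilde{\mathbf{x}}_i^\top \tilde{\mathbf{x}}_j &= \left( \mathbf{x}_i - \frac{1}{N} \sum_{k=1}^{N} \mathbf{x}_k \right)^\top \left( \mathbf{x}_j - \frac{1}{N} \sum_{l=1}^{N} \mathbf{x}_l \right) \tag{7.90} \\
&= \mathbf{x}_i^\top \mathbf{x}_j - \frac{1}{N} \sum_{k=1}^{N} \mathbf{x}_i^\top \mathbf{x}_k - \frac{1}{N} \sum_{k=1}^{N} \mathbf{x}_j^\top \mathbf{x}_k + \frac{1}{N^2} \sum_{k=1}^{N} \sum_{l=1}^{N} \mathbf{x}_k^\top \mathbf{x}_l \tag{7.91}
\end{align}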

7.2.4.5 Distance matrix 


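Let $\mathbf{X}$ be an $N_x \times D$ data matrix, and $\mathbf{Y}$ be another $N_y \times D$ data matrix. We can compute the squared pairwise distances between these using


\[ D_{ij} = (\mathbf{x}_i - \mathbf{y}_j)^\top (\mathbf{x}_i - \mathbf{y}_j) = \|\mathbf{x}_i\|^2 - 2 \mathbf{x}_i^\top \mathbf{y}_j + \|\mathbf{y}_j\|^2 \tag{7.92} \]
Let us now write this in matrix form. Let $\hat{\mathbf{x}} = [\|\mathbf{x}_1\|^2; \cdots; \|\mathbf{x}_{N_x}\|^2] = \mathrm{diag}(\mathbf{X}\mathbf{X}^\top)$ be a vector where each element is the squared norm of the examples in $\mathbf{X}$, and define $\hat{\mathbf{y}}$ similarly. Then we have

\[ \mathbf{D} = \hat{\mathbf{x}} \mathbf{1}_{N_y}^\top - 2 \mathbf{X}\mathbf{Y}^\top + \mathbf{1}_{N_x} \hat{\mathbf{y}}^\top \tag{7.93} \]
In the case that $\mathbf{X} = \mathbf{Y}$, we have

\[ \mathbf{D} = \hat{\mathbf{x}} \mathbf{1}_N^\top - 2 \mathbf{X}\mathbf{X}^\top + \mathbf{1}_N \hat{\mathbf{x}}^\top \tag{7.94} \]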

This vectorized computation is often much faster than using for loops. 
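A vectorized version of Equation (7.93) might look as follows. This is a sketch: the function name `pairwise_sq_dists` is made up for the example, NumPy broadcasting stands in for the explicit products with vectors of ones, and the final clipping at zero is an added safeguard against small negative values caused by round-off (it is not part of the formula).

```python
import numpy as np

def pairwise_sq_dists(X, Y):
    """Squared Euclidean distances between rows of X (Nx x D) and Y (Ny x D)."""
    x_hat = np.sum(X**2, axis=1)          # squared norms of rows of X
    y_hat = np.sum(Y**2, axis=1)          # squared norms of rows of Y
    # Broadcasting plays the role of x_hat 1^T and 1 y_hat^T in Equation (7.93).
    D = x_hat[:, None] - 2.0 * X @ Y.T + y_hat[None, :]
    return np.maximum(D, 0.0)             # guard against round-off (assumption)

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))
Y = rng.standard_normal((5, 3))
D_vectorized = pairwise_sq_dists(X, Y)

# Reference computed with explicit loops, following Equation (7.92).
D_loops = np.array([[np.sum((x - y) ** 2) for y in Y] for x in X])
```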
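7.2.5 Kronecker products *


If $\mathbf{A}$ is an $m \times n$ matrix and $\mathbf{B}$ is a $p \times q$ matrix, then the Kronecker product $\mathbf{A} \otimes \mathbf{B}$ is the $mp \times nq$ block matrix


\[ \mathbf{A} \otimes \mathbf{B} = \begin{pmatrix} a_{11}\mathbf{B} & \cdots & a_{1n}\mathbf{B} \\ \vdots & \ddots & \vdots \\ a_{m1}\mathbf{B} & \cdots & a_{mn}\mathbf{B} \end{pmatrix} \tag{7.95} \]

For example,


\[ \begin{pmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \\ a_{31} & a_{32} \end{pmatrix} \otimes \begin{pmatrix} b_{11} & b_{12} & b_{13} \\ b_{21} & b_{22} & b_{23} \end{pmatrix} = \begin{pmatrix}
a_{11}b_{11} & a_{11}b_{12} & a_{11}b_{13} & a_{12}b_{11} & a_{12}b_{12} & a_{12}b_{13} \\
a_{11}b_{21} & a_{11}b_{22} & a_{11}b_{23} & a_{12}b_{21} & a_{12}b_{22} & a_{12}b_{23} \\
a_{21}b_{11} & a_{21}b_{12} & a_{21}b_{13} & a_{22}b_{11} & a_{22}b_{12} & a_{22}b_{13} \\
a_{21}b_{21} & a_{21}b_{22} & a_{21}b_{23} & a_{22}b_{21} & a_{22}b_{22} & a_{22}b_{23} \\
a_{31}b_{11} & a_{31}b_{12} & a_{31}b_{13} & a_{32}b_{11} & a_{32}b_{12} & a_{32}b_{13} \\
a_{31}b_{21} & a_{31}b_{22} & a_{31}b_{23} & a_{32}b_{21} & a_{32}b_{22} & a_{32}b_{23}
\end{pmatrix} \tag{7.96} \]

Here are some useful identities:

\begin{align}
(\mathbf{A} \otimes \mathbf{B})^{-1} &= \mathbf{A}^{-1} \otimes \mathbf{B}^{-1} \tag{7.97} \\
(\mathbf{A} \otimes \mathbf{B}) \mathrm{vec}(\mathbf{C}) &= \mathrm{vec}(\mathbf{B}\mathbf{C}\mathbf{A}^\top) \tag{7.98}
\end{align}


where $\mathrm{vec}(\mathbf{M})$ stacks the columns of $\mathbf{M}$. (If we stack along the rows, we get $(\mathbf{A} \otimes \mathbf{B}) \mathrm{vec}(\mathbf{C}) = \mathrm{vec}(\mathbf{A}\mathbf{C}\mathbf{B}^\top)$.) See [Loa00] for a list of other useful properties.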
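The vec identity in Equation (7.98), and its row-stacking counterpart, can be checked with `np.kron`. In the sketch below the helper names `vec_cols` and `vec_rows` are introduced for illustration; they correspond to column-major (`order="F"`) and row-major (`order="C"`) flattening, and the shapes of `C` are chosen so that the products are conformable.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 2))    # m x n
B = rng.standard_normal((4, 5))    # p x q
K = np.kron(A, B)                  # (m p) x (n q) = 12 x 10

def vec_cols(M):
    """Stack the columns of M into one long vector."""
    return M.reshape(-1, order="F")

def vec_rows(M):
    """Stack the rows of M into one long vector."""
    return M.reshape(-1, order="C")

# Column stacking: C must be q x n so that B C A^T is p x m.
C = rng.standard_normal((5, 2))
lhs_cols = K @ vec_cols(C)
rhs_cols = vec_cols(B @ C @ A.T)

# Row stacking: C must be n x q so that A C B^T is m x p.
C_r = rng.standard_normal((2, 5))
lhs_rows = K @ vec_rows(C_r)
rhs_rows = vec_rows(A @ C_r @ B.T)
```

The practical point is that the right-hand sides never form the large Kronecker product.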


7.2.6 Einstein summation * 


Einstein summation, or einsum for short, is a notational shortcut for working with tensors. The convention was introduced by Einstein [Ein16, sec 5], who later joked to a friend, “I have made a great discovery in mathematics; I have suppressed the summation sign every time that the summation must be made over an index which occurs twice...” [Pai05, p.216]. For example, instead of writing matrix multiplication as $C_{ij} = \sum_k A_{ik} B_{kj}$, we can just write it as $C_{ij} = A_{ik} B_{kj}$, where we drop the $\sum_k$.

As a more complex example, suppose we have a 3d tensor $S_{ntk}$ where $n$ indexes examples in the batch, $t$ indexes locations in the sequence, and $k$ indexes words in a one-hot representation. Let $W_{kd}$ be an embedding matrix that maps sparse one-hot vectors in $\mathbb{R}^k$ to dense vectors in $\mathbb{R}^d$. We can convert the batch of sequences of one-hots to a batch of sequences of embeddings as follows:


\[ E_{ntd} = \sum_{k} S_{ntk} W_{kd} \tag{7.99} \]


We can compute the sum of the embedding vectors for each sequence (to get a global representation of each bag of words) as follows:


\[ E_{nd} = \sum_{k} \sum_{t} S_{ntk} W_{kd} \tag{7.100} \]

Finally we can pass each sequence's vector representation through another linear transform $V_{dc}$ to map to the logits over a classifier with $c$ labels:


\[ L_{nc} = \sum_{d} E_{nd} V_{dc} = \sum_{d} \sum_{k} \sum_{t} S_{ntk} W_{kd} V_{dc} \tag{7.101} \]


In einsum notation, we write $L_{nc} = S_{ntk} W_{kd} V_{dc}$. We sum over $k$ and $d$ because those indices occur twice on the RHS. We sum over $t$ because that index does not occur on the LHS.

Einsum is implemented in NumPy, Tensorflow, PyTorch, etc. What makes it particularly useful is that it can perform the relevant tensor multiplications in complex expressions in an optimal order, so as to minimize time and intermediate memory allocation.² The library is best illustrated by the examples in einsum_demo.ipynb.

Note that the speed of einsum depends on the order in which the operations are performed, which 
depends on the shapes of the relevant arguments. The optimal ordering minimizes the treewidth 
of the resulting computation graph, as explained in [GASG18]. In general, the time to compute 
the optimal ordering is exponential in the number of arguments, so it is common to use a greedy 
approximation. However, if we expect to repeat the same calculation many times, using tensors of 
the same shape but potentially different content, we can compute the optimal ordering once and 
reuse it multiple times. 
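The bag-of-words example in Equations (7.99) to (7.101) maps directly onto `np.einsum` subscript strings. The following is a small sketch with made-up sizes and random word indices; it is separate from the book's einsum_demo.ipynb notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, K, D, C = 2, 5, 7, 3, 4                  # batch, sequence, vocab, embed, classes
words = rng.integers(0, K, size=(N, T))        # random word indices (assumption)
S = np.eye(K)[words]                           # one-hot tensor of shape (N, T, K)
W = rng.standard_normal((K, D))                # embedding matrix
V = rng.standard_normal((D, C))                # classifier weights

E_ntd = np.einsum("ntk,kd->ntd", S, W)         # Equation (7.99)
E_nd = np.einsum("ntk,kd->nd", S, W)           # Equation (7.100): also sums over t
# Equation (7.101); optimize=True lets NumPy choose the contraction order.
L_nc = np.einsum("ntk,kd,dc->nc", S, W, V, optimize=True)

# Equivalent computation by table lookup, for comparison.
E_lookup = W[words]                            # shape (N, T, D)
```

Indices that are absent from the output string (here `t`, `k` and `d` in the last call) are the ones that get summed over.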


7.3 Matrix inversion 


In this section, we discuss how to invert different kinds of matrices. 


7.3.1 The inverse of a square matrix 
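The inverse of a square matrix $\mathbf{A} \in \mathbb{R}^{n \times n}$ is denoted $\mathbf{A}^{-1}$, and is the unique matrix such that
\[ \mathbf{A}^{-1}\mathbf{A} = \mathbf{I} = \mathbf{A}\mathbf{A}^{-1} \tag{7.102} \]


Note that $\mathbf{A}^{-1}$ exists if and only if $\det(\mathbf{A}) \neq 0$. If $\det(\mathbf{A}) = 0$, it is called a singular matrix.
The following are properties of the inverse; all assume that $\mathbf{A}, \mathbf{B} \in \mathbb{R}^{n \times n}$ are non-singular:


\begin{align}
(\mathbf{A}^{-1})^{-1} &= \mathbf{A} \tag{7.103} \\
(\mathbf{A}\mathbf{B})^{-1} &= \mathbf{B}^{-1}\mathbf{A}^{-1} \tag{7.104} \\
(\mathbf{A}^{-1})^\top &= (\mathbf{A}^\top)^{-1} \triangleq \mathbf{A}^{-\top} \tag{7.105}
\end{align}


For the case of a $2 \times 2$ matrix, the expression for $\mathbf{A}^{-1}$ is simple enough to give explicitly. We have


\[ \mathbf{A} = \begin{pmatrix} a & b \\ c & d \end{pmatrix}, \quad \mathbf{A}^{-1} = \frac{1}{|\mathbf{A}|} \begin{pmatrix} d & -b \\ -c & a \end{pmatrix} \tag{7.106} \]
For a block diagonal matrix, the inverse is obtained by simply inverting each block separately, e.g.,


\[ \begin{pmatrix} \mathbf{A} & \mathbf{0} \\ \mathbf{0} & \mathbf{B} \end{pmatrix}^{-1} = \begin{pmatrix} \mathbf{A}^{-1} & \mathbf{0} \\ \mathbf{0} & \mathbf{B}^{-1} \end{pmatrix} \tag{7.107} \]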


2. These optimizations are implemented in the opt-einsum library [GASG18]. Its core functionality is included in the NumPy and JAX einsum functions, provided you set the optimize=True parameter.


7.3.2 Schur complements *


In this section, we review some useful results concerning block structured matrices. 


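Theorem 7.3.1 (Inverse of a partitioned matrix). Consider a general partitioned matrix


\[ \mathbf{M} = \begin{pmatrix} \mathbf{E} & \mathbf{F} \\ \mathbf{G} & \mathbf{H} \end{pmatrix} \tag{7.108} \]
where we assume $\mathbf{E}$ and $\mathbf{H}$ are invertible. We have


\begin{align}
\mathbf{M}^{-1} &= \begin{pmatrix} (\mathbf{M}/\mathbf{H})^{-1} & -(\mathbf{M}/\mathbf{H})^{-1}\mathbf{F}\mathbf{H}^{-1} \\ -\mathbf{H}^{-1}\mathbf{G}(\mathbf{M}/\mathbf{H})^{-1} & \mathbf{H}^{-1} + \mathbf{H}^{-1}\mathbf{G}(\mathbf{M}/\mathbf{H})^{-1}\mathbf{F}\mathbf{H}^{-1} \end{pmatrix} \tag{7.109} \\
&= \begin{pmatrix} \mathbf{E}^{-1} + \mathbf{E}^{-1}\mathbf{F}(\mathbf{M}/\mathbf{E})^{-1}\mathbf{G}\mathbf{E}^{-1} & -\mathbf{E}^{-1}\mathbf{F}(\mathbf{M}/\mathbf{E})^{-1} \\ -(\mathbf{M}/\mathbf{E})^{-1}\mathbf{G}\mathbf{E}^{-1} & (\mathbf{M}/\mathbf{E})^{-1} \end{pmatrix} \tag{7.110}
\end{align}


where


\begin{align}
\mathbf{M}/\mathbf{H} &\triangleq \mathbf{E} - \mathbf{F}\mathbf{H}^{-1}\mathbf{G} \tag{7.111} \\
\mathbf{M}/\mathbf{E} &\triangleq \mathbf{H} - \mathbf{G}\mathbf{E}^{-1}\mathbf{F} \tag{7.112}
\end{align}


We say that $\mathbf{M}/\mathbf{H}$ is the Schur complement of $\mathbf{M}$ wrt $\mathbf{H}$, and $\mathbf{M}/\mathbf{E}$ is the Schur complement of $\mathbf{M}$ wrt $\mathbf{E}$.


Equation (7.109) and Equation (7.110) are called the partitioned inverse formulae.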


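Proof. If we could block diagonalize $\mathbf{M}$, it would be easier to invert. To zero out the top right block of $\mathbf{M}$ we can pre-multiply as follows

\[ \begin{pmatrix} \mathbf{I} & -\mathbf{F}\mathbf{H}^{-1} \\ \mathbf{0} & \mathbf{I} \end{pmatrix} \begin{pmatrix} \mathbf{E} & \mathbf{F} \\ \mathbf{G} & \mathbf{H} \end{pmatrix} = \begin{pmatrix} \mathbf{E} - \mathbf{F}\mathbf{H}^{-1}\mathbf{G} & \mathbf{0} \\ \mathbf{G} & \mathbf{H} \end{pmatrix} \tag{7.113} \]
Similarly, to zero out the bottom left we can post-multiply as follows
\[ \begin{pmatrix} \mathbf{E} - \mathbf{F}\mathbf{H}^{-1}\mathbf{G} & \mathbf{0} \\ \mathbf{G} & \mathbf{H} \end{pmatrix} \begin{pmatrix} \mathbf{I} & \mathbf{0} \\ -\mathbf{H}^{-1}\mathbf{G} & \mathbf{I} \end{pmatrix} = \begin{pmatrix} \mathbf{E} - \mathbf{F}\mathbf{H}^{-1}\mathbf{G} & \mathbf{0} \\ \mathbf{0} & \mathbf{H} \end{pmatrix} \tag{7.114} \]
Putting it all together we get


\[ \underbrace{\begin{pmatrix} \mathbf{I} & -\mathbf{F}\mathbf{H}^{-1} \\ \mathbf{0} & \mathbf{I} \end{pmatrix}}_{\mathbf{X}} \underbrace{\begin{pmatrix} \mathbf{E} & \mathbf{F} \\ \mathbf{G} & \mathbf{H} \end{pmatrix}}_{\mathbf{M}} \underbrace{\begin{pmatrix} \mathbf{I} & \mathbf{0} \\ -\mathbf{H}^{-1}\mathbf{G} & \mathbf{I} \end{pmatrix}}_{\mathbf{Z}} = \underbrace{\begin{pmatrix} \mathbf{E} - \mathbf{F}\mathbf{H}^{-1}\mathbf{G} & \mathbf{0} \\ \mathbf{0} & \mathbf{H} \end{pmatrix}}_{\mathbf{W}} \tag{7.115} \]


Taking the inverse of both sides yields


\begin{align}
\mathbf{Z}^{-1}\mathbf{M}^{-1}\mathbf{X}^{-1} &= \mathbf{W}^{-1} \tag{7.116} \\
\mathbf{M}^{-1} &= \mathbf{Z}\mathbf{W}^{-1}\mathbf{X} \tag{7.117}
\end{align}

Substituting in the definitions we get


\begin{align}
\begin{pmatrix} \mathbf{E} & \mathbf{F} \\ \mathbf{G} & \mathbf{H} \end{pmatrix}^{-1} &= \begin{pmatrix} \mathbf{I} & \mathbf{0} \\ -\mathbf{H}^{-1}\mathbf{G} & \mathbf{I} \end{pmatrix} \begin{pmatrix} (\mathbf{M}/\mathbf{H})^{-1} & \mathbf{0} \\ \mathbf{0} & \mathbf{H}^{-1} \end{pmatrix} \begin{pmatrix} \mathbf{I} & -\mathbf{F}\mathbf{H}^{-1} \\ \mathbf{0} & \mathbf{I} \end{pmatrix} \tag{7.118} \\
&= \begin{pmatrix} (\mathbf{M}/\mathbf{H})^{-1} & \mathbf{0} \\ -\mathbf{H}^{-1}\mathbf{G}(\mathbf{M}/\mathbf{H})^{-1} & \mathbf{H}^{-1} \end{pmatrix} \begin{pmatrix} \mathbf{I} & -\mathbf{F}\mathbf{H}^{-1} \\ \mathbf{0} & \mathbf{I} \end{pmatrix} \tag{7.119} \\
&= \begin{pmatrix} (\mathbf{M}/\mathbf{H})^{-1} & -(\mathbf{M}/\mathbf{H})^{-1}\mathbf{F}\mathbf{H}^{-1} \\ -\mathbf{H}^{-1}\mathbf{G}(\mathbf{M}/\mathbf{H})^{-1} & \mathbf{H}^{-1} + \mathbf{H}^{-1}\mathbf{G}(\mathbf{M}/\mathbf{H})^{-1}\mathbf{F}\mathbf{H}^{-1} \end{pmatrix} \tag{7.120}
\end{align}
Alternatively, we could have decomposed the matrix $\mathbf{M}$ in terms of $\mathbf{E}$ and $\mathbf{M}/\mathbf{E} = (\mathbf{H} - \mathbf{G}\mathbf{E}^{-1}\mathbf{F})$, yielding
\[ \begin{pmatrix} \mathbf{E} & \mathbf{F} \\ \mathbf{G} & \mathbf{H} \end{pmatrix}^{-1} = \begin{pmatrix} \mathbf{E}^{-1} + \mathbf{E}^{-1}\mathbf{F}(\mathbf{M}/\mathbf{E})^{-1}\mathbf{G}\mathbf{E}^{-1} & -\mathbf{E}^{-1}\mathbf{F}(\mathbf{M}/\mathbf{E})^{-1} \\ -(\mathbf{M}/\mathbf{E})^{-1}\mathbf{G}\mathbf{E}^{-1} & (\mathbf{M}/\mathbf{E})^{-1} \end{pmatrix} \tag{7.121} \]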
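The partitioned inverse formula can be sanity-checked numerically. The sketch below builds a symmetric positive definite $\mathbf{M}$ (an assumption made so that $\mathbf{E}$, $\mathbf{H}$ and the Schur complement are guaranteed to be invertible), assembles the block expression of Equation (7.120), and compares it with a direct inverse.

```python
import numpy as np

rng = np.random.default_rng(0)
p, q = 3, 2
B = rng.standard_normal((p + q, p + q))
M = B @ B.T + np.eye(p + q)       # SPD, so all required inverses exist

# Partition M into blocks E (p x p), F (p x q), G (q x p), H (q x q).
E, F = M[:p, :p], M[:p, p:]
G, H = M[p:, :p], M[p:, p:]

H_inv = np.linalg.inv(H)
schur_H = E - F @ H_inv @ G       # Schur complement M/H
schur_H_inv = np.linalg.inv(schur_H)

# Block formula for the inverse, Equation (7.120).
M_inv_blocks = np.block([
    [schur_H_inv, -schur_H_inv @ F @ H_inv],
    [-H_inv @ G @ schur_H_inv, H_inv + H_inv @ G @ schur_H_inv @ F @ H_inv],
])
M_inv_direct = np.linalg.inv(M)
```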


7.3.3 The matrix inversion lemma * 


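Equating the top left block of the first matrix in Equation (7.119) with the top left block of the matrix in Equation (7.121), we get


\[ (\mathbf{M}/\mathbf{H})^{-1} = (\mathbf{E} - \mathbf{F}\mathbf{H}^{-1}\mathbf{G})^{-1} = \mathbf{E}^{-1} + \mathbf{E}^{-1}\mathbf{F}(\mathbf{H} - \mathbf{G}\mathbf{E}^{-1}\mathbf{F})^{-1}\mathbf{G}\mathbf{E}^{-1} \tag{7.122} \]


This is known as the matrix inversion lemma or the Sherman-Morrison-Woodbury formula.

A typical application in machine learning is the following. Let $\mathbf{X}$ be an $N \times D$ data matrix, and $\boldsymbol{\Sigma}$ be an $N \times N$ diagonal matrix. Then we have (using the substitutions $\mathbf{E} = \boldsymbol{\Sigma}$, $\mathbf{F} = \mathbf{G}^\top = \mathbf{X}$, and $\mathbf{H}^{-1} = -\mathbf{I}$) the following result:


\[ (\boldsymbol{\Sigma} + \mathbf{X}\mathbf{X}^\top)^{-1} = \boldsymbol{\Sigma}^{-1} - \boldsymbol{\Sigma}^{-1}\mathbf{X}(\mathbf{I} + \mathbf{X}^\top \boldsymbol{\Sigma}^{-1} \mathbf{X})^{-1} \mathbf{X}^\top \boldsymbol{\Sigma}^{-1} \tag{7.123} \]
The LHS takes $O(N^3)$ time to compute, the RHS takes time $O(D^3)$ to compute.


Another application concerns computing a rank one update of an inverse matrix. Let $\mathbf{E} = \mathbf{A}$, $\mathbf{F} = \mathbf{u}$, $\mathbf{G} = \mathbf{v}^\top$, and $\mathbf{H} = -1$. Then we have


\begin{align}
(\mathbf{A} + \mathbf{u}\mathbf{v}^\top)^{-1} &= \mathbf{A}^{-1} + \mathbf{A}^{-1}\mathbf{u}(-1 - \mathbf{v}^\top \mathbf{A}^{-1} \mathbf{u})^{-1} \mathbf{v}^\top \mathbf{A}^{-1} \tag{7.124} \\
&= \mathbf{A}^{-1} - \frac{\mathbf{A}^{-1}\mathbf{u}\mathbf{v}^\top \mathbf{A}^{-1}}{1 + \mathbf{v}^\top \mathbf{A}^{-1} \mathbf{u}} \tag{7.125}
\end{align}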


This is known as the Sherman-Morrison formula. 
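Both identities can be checked numerically. The following sketch compares each side of Equation (7.123) on random data, and applies the rank one update of Equation (7.125) to a small hand-picked matrix; the specific matrices and vectors are assumptions chosen so that every inverse involved exists.

```python
import numpy as np

rng = np.random.default_rng(0)

# Matrix inversion lemma, Equation (7.123).
N, D = 8, 3
X = rng.standard_normal((N, D))
sigma_diag = rng.uniform(0.5, 2.0, size=N)     # positive diagonal of Sigma
Sigma_inv = np.diag(1.0 / sigma_diag)          # cheap because Sigma is diagonal

lhs = np.linalg.inv(np.diag(sigma_diag) + X @ X.T)       # N x N inverse
small = np.eye(D) + X.T @ Sigma_inv @ X                   # only D x D
rhs = Sigma_inv - Sigma_inv @ X @ np.linalg.solve(small, X.T @ Sigma_inv)

# Sherman-Morrison rank one update, Equation (7.125).
A = np.array([[4.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]])
u = np.array([1.0, 0.0, -1.0])
v = np.array([0.5, 1.0, 0.0])
A_inv = np.linalg.inv(A)
denom = 1.0 + v @ A_inv @ u                    # must be non-zero
updated_inv = A_inv - np.outer(A_inv @ u, v @ A_inv) / denom
```

When $D \ll N$, the right-hand side of Equation (7.123) only requires solving a $D \times D$ system, which is where the savings come from.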


7.3.4 Matrix determinant lemma * 


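We now use the above results to derive an efficient way to compute the determinant of a block-structured matrix.
From Equation (7.115), and noting that $\mathbf{X}$ and $\mathbf{Z}$ are triangular with ones on the diagonal (so $|\mathbf{X}| = |\mathbf{Z}| = 1$), we have


\begin{align}
|\mathbf{X}||\mathbf{M}||\mathbf{Z}| &= |\mathbf{W}| = |\mathbf{E} - \mathbf{F}\mathbf{H}^{-1}\mathbf{G}||\mathbf{H}| \tag{7.126} \\
\left| \begin{pmatrix} \mathbf{E} & \mathbf{F} \\ \mathbf{G} & \mathbf{H} \end{pmatrix} \right| &= |\mathbf{E} - \mathbf{F}\mathbf{H}^{-1}\mathbf{G}||\mathbf{H}| \tag{7.127} \\
|\mathbf{M}| &= |\mathbf{M}/\mathbf{H}||\mathbf{H}| \tag{7.128} \\
|\mathbf{M}/\mathbf{H}| &= \frac{|\mathbf{M}|}{|\mathbf{H}|} \tag{7.129}
\end{align}


So we can see that $\mathbf{M}/\mathbf{H}$ acts somewhat like a division operator (hence the notation).


Furthermore, we have


\begin{align}
|\mathbf{M}| &= |\mathbf{M}/\mathbf{H}||\mathbf{H}| = |\mathbf{M}/\mathbf{E}||\mathbf{E}| \tag{7.130} \\
|\mathbf{M}/\mathbf{H}| &= \frac{|\mathbf{M}/\mathbf{E}||\mathbf{E}|}{|\mathbf{H}|} \tag{7.131} \\
|\mathbf{E} - \mathbf{F}\mathbf{H}^{-1}\mathbf{G}| &= |\mathbf{H} - \mathbf{G}\mathbf{E}^{-1}\mathbf{F}||\mathbf{H}^{-1}||\mathbf{E}| \tag{7.132}
\end{align}


Hence (setting $\mathbf{E} = \mathbf{A}$, $\mathbf{F} = -\mathbf{u}$, $\mathbf{G} = \mathbf{v}^\top$, $\mathbf{H} = 1$) we have
\[ |\mathbf{A} + \mathbf{u}\mathbf{v}^\top| = (1 + \mathbf{v}^\top \mathbf{A}^{-1} \mathbf{u}) |\mathbf{A}| \tag{7.133} \]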


This is known as the matrix determinant lemma. 
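As a quick numerical check of the lemma, the sketch below evaluates both sides for a small hand-picked matrix and pair of vectors (chosen for illustration only). For this choice $|\mathbf{A}| = 18$ and $1 + \mathbf{v}^\top \mathbf{A}^{-1} \mathbf{u} = 11/9$, so both sides equal 22.

```python
import numpy as np

A = np.array([[4.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]])
u = np.array([1.0, 0.0, -1.0])
v = np.array([0.5, 1.0, 0.0])

# Left-hand side: determinant of the rank one update, computed directly.
det_updated = np.linalg.det(A + np.outer(u, v))

# Right-hand side: reuse det(A) and one linear solve with A.
det_lemma = (1.0 + v @ np.linalg.solve(A, u)) * np.linalg.det(A)
```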


7.3.5 Application: deriving the conditionals of an MVN * 


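Consider a joint Gaussian of the form $p(\mathbf{x}_1, \mathbf{x}_2) = \mathcal{N}(\mathbf{x} | \boldsymbol{\mu}, \boldsymbol{\Sigma})$, where
\[ \boldsymbol{\mu} = \begin{pmatrix} \boldsymbol{\mu}_1 \\ \boldsymbol{\mu}_2 \end{pmatrix}, \quad \boldsymbol{\Sigma} = \begin{pmatrix} \boldsymbol{\Sigma}_{11} & \boldsymbol{\Sigma}_{12} \\ \boldsymbol{\Sigma}_{21} & \boldsymbol{\Sigma}_{22} \end{pmatrix} \tag{7.134} \]
In Section 3.2.3, we claimed that


\[ p(\mathbf{x}_1 | \mathbf{x}_2) = \mathcal{N}(\mathbf{x}_1 | \boldsymbol{\mu}_1 + \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}(\mathbf{x}_2 - \boldsymbol{\mu}_2), \ \boldsymbol{\Sigma}_{11} - \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}\boldsymbol{\Sigma}_{21}) \tag{7.135} \]


In this section, we derive this result using Schur complements.
Let us factor the joint $p(\mathbf{x}_1, \mathbf{x}_2)$ as $p(\mathbf{x}_2) p(\mathbf{x}_1 | \mathbf{x}_2)$ as follows:


\[ p(\mathbf{x}_1, \mathbf{x}_2) \propto \exp\left\{ -\frac{1}{2} \begin{pmatrix} \mathbf{x}_1 - \boldsymbol{\mu}_1 \\ \mathbf{x}_2 - \boldsymbol{\mu}_2 \end{pmatrix}^\top \begin{pmatrix} \boldsymbol{\Sigma}_{11} & \boldsymbol{\Sigma}_{12} \\ \boldsymbol{\Sigma}_{21} & \boldsymbol{\Sigma}_{22} \end{pmatrix}^{-1} \begin{pmatrix} \mathbf{x}_1 - \boldsymbol{\mu}_1 \\ \mathbf{x}_2 - \boldsymbol{\mu}_2 \end{pmatrix} \right\} \tag{7.136} \]


Using Equation (7.118) the above exponent becomes


\begin{align}
p(\mathbf{x}_1, \mathbf{x}_2) &\propto \exp\Big\{ -\frac{1}{2} \begin{pmatrix} \mathbf{x}_1 - \boldsymbol{\mu}_1 \\ \mathbf{x}_2 - \boldsymbol{\mu}_2 \end{pmatrix}^\top \begin{pmatrix} \mathbf{I} & \mathbf{0} \\ -\boldsymbol{\Sigma}_{22}^{-1}\boldsymbol{\Sigma}_{21} & \mathbf{I} \end{pmatrix} \begin{pmatrix} (\boldsymbol{\Sigma}/\boldsymbol{\Sigma}_{22})^{-1} & \mathbf{0} \\ \mathbf{0} & \boldsymbol{\Sigma}_{22}^{-1} \end{pmatrix} \tag{7.137} \\
&\qquad \times \begin{pmatrix} \mathbf{I} & -\boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1} \\ \mathbf{0} & \mathbf{I} \end{pmatrix} \begin{pmatrix} \mathbf{x}_1 - \boldsymbol{\mu}_1 \\ \mathbf{x}_2 - \boldsymbol{\mu}_2 \end{pmatrix} \Big\} \tag{7.138} \\
&= \exp\Big\{ -\frac{1}{2} (\mathbf{x}_1 - \boldsymbol{\mu}_1 - \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}(\mathbf{x}_2 - \boldsymbol{\mu}_2))^\top (\boldsymbol{\Sigma}/\boldsymbol{\Sigma}_{22})^{-1} \tag{7.139} \\
&\qquad (\mathbf{x}_1 - \boldsymbol{\mu}_1 - \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}(\mathbf{x}_2 - \boldsymbol{\mu}_2)) \Big\} \times \exp\Big\{ -\frac{1}{2} (\mathbf{x}_2 - \boldsymbol{\mu}_2)^\top \boldsymbol{\Sigma}_{22}^{-1} (\mathbf{x}_2 - \boldsymbol{\mu}_2) \Big\} \tag{7.140}
\end{align}


This is of the form
\[ \exp(\text{quadratic form in } \mathbf{x}_1, \mathbf{x}_2) \times \exp(\text{quadratic form in } \mathbf{x}_2) \tag{7.141} \]
Hence we have successfully factorized the joint as


\begin{align}
p(\mathbf{x}_1, \mathbf{x}_2) &= p(\mathbf{x}_1 | \mathbf{x}_2) p(\mathbf{x}_2) \tag{7.142} \\
&= \mathcal{N}(\mathbf{x}_1 | \boldsymbol{\mu}_{1|2}, \boldsymbol{\Sigma}_{1|2}) \, \mathcal{N}(\mathbf{x}_2 | \boldsymbol{\mu}_2, \boldsymbol{\Sigma}_{22}) \tag{7.143}
\end{align}


where the parameters of the conditional distribution can be read off from the above equations using


\begin{align}
\boldsymbol{\mu}_{1|2} &= \boldsymbol{\mu}_1 + \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}(\mathbf{x}_2 - \boldsymbol{\mu}_2) \tag{7.144} \\
\boldsymbol{\Sigma}_{1|2} &= \boldsymbol{\Sigma}/\boldsymbol{\Sigma}_{22} = \boldsymbol{\Sigma}_{11} - \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}\boldsymbol{\Sigma}_{21} \tag{7.145}
\end{align}

We can also use the fact that $|\mathbf{M}| = |\mathbf{M}/\mathbf{H}||\mathbf{H}|$ to check the normalization constants are correct:
\begin{align}
(2\pi)^{(d_1 + d_2)/2} |\boldsymbol{\Sigma}|^{\frac{1}{2}} &= (2\pi)^{(d_1 + d_2)/2} \left( |\boldsymbol{\Sigma}/\boldsymbol{\Sigma}_{22}| \, |\boldsymbol{\Sigma}_{22}| \right)^{\frac{1}{2}} \tag{7.146} \\
&= (2\pi)^{d_1/2} |\boldsymbol{\Sigma}/\boldsymbol{\Sigma}_{22}|^{\frac{1}{2}} \, (2\pi)^{d_2/2} |\boldsymbol{\Sigma}_{22}|^{\frac{1}{2}} \tag{7.147}
\end{align}


where $d_1 = \dim(\mathbf{x}_1)$ and $d_2 = \dim(\mathbf{x}_2)$.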
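The conditioning formulas in Equations (7.144) and (7.145) can be computed with two linear solves. The sketch below uses a randomly generated positive definite covariance (an assumption for illustration) and cross-checks the result against the equivalent expressions in terms of the precision matrix $\boldsymbol{\Lambda} = \boldsymbol{\Sigma}^{-1}$, namely $\boldsymbol{\Sigma}_{1|2} = \boldsymbol{\Lambda}_{11}^{-1}$ and $\boldsymbol{\mu}_{1|2} = \boldsymbol{\mu}_1 - \boldsymbol{\Lambda}_{11}^{-1}\boldsymbol{\Lambda}_{12}(\mathbf{x}_2 - \boldsymbol{\mu}_2)$, which follow from the block structure of Equation (7.109).

```python
import numpy as np

rng = np.random.default_rng(0)
d1, d2 = 2, 3
B = rng.standard_normal((d1 + d2, d1 + d2))
Sigma = B @ B.T + np.eye(d1 + d2)          # positive definite joint covariance
mu = rng.standard_normal(d1 + d2)

S11, S12 = Sigma[:d1, :d1], Sigma[:d1, d1:]
S21, S22 = Sigma[d1:, :d1], Sigma[d1:, d1:]
mu1, mu2 = mu[:d1], mu[d1:]
x2 = rng.standard_normal(d2)               # observed value of x2 (arbitrary)

# Equations (7.144) and (7.145), using solves instead of explicit inverses.
mu_cond = mu1 + S12 @ np.linalg.solve(S22, x2 - mu2)
Sigma_cond = S11 - S12 @ np.linalg.solve(S22, S21)    # Schur complement

# Cross-check via the precision matrix.
Lam = np.linalg.inv(Sigma)
Lam11, Lam12 = Lam[:d1, :d1], Lam[:d1, d1:]
mu_cond_precision = mu1 - np.linalg.solve(Lam11, Lam12 @ (x2 - mu2))
Sigma_cond_precision = np.linalg.inv(Lam11)
```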


7.4 Eigenvalue decomposition (EVD) 


In this section, we review some standard material on the eigenvalue decomposition or EVD of 
square (real-valued) matrices. 


7.4.1 Basics 


Given a square matrix $\mathbf{A} \in \mathbb{R}^{n \times n}$, we say that $\lambda \in \mathbb{R}$ is an eigenvalue of $\mathbf{A}$ and $\mathbf{u} \in \mathbb{R}^n$ is the corresponding eigenvector if


\[ \mathbf{A}\mathbf{u} = \lambda \mathbf{u}, \quad \mathbf{u} \neq \mathbf{0} \tag{7.148} \]

Intuitively, this definition means that multiplying $\mathbf{A}$ by the vector $\mathbf{u}$ results in a new vector that points in the same direction as $\mathbf{u}$, but is scaled by a factor $\lambda$. For example, if $\mathbf{A}$ is a 3d rotation matrix, then $\mathbf{u}$ is the axis of rotation and $\lambda = 1$.

Note that for any eigenvector $\mathbf{u} \in \mathbb{R}^n$, and nonzero scalar $c \in \mathbb{R}$,


\[ \mathbf{A}(c\mathbf{u}) = c\mathbf{A}\mathbf{u} = c\lambda\mathbf{u} = \lambda(c\mathbf{u}) \tag{7.149} \]


Hence $c\mathbf{u}$ is also an eigenvector. For this reason when we talk about “the” eigenvector associated with $\lambda$, we usually assume that the eigenvector is normalized to have length 1 (this still creates some ambiguity, since $\mathbf{u}$ and $-\mathbf{u}$ will both be eigenvectors, but we will have to live with this).

We can rewrite the equation above to state that $(\lambda, \mathbf{u})$ is an eigenvalue-eigenvector pair of $\mathbf{A}$ if


\[ (\lambda\mathbf{I} - \mathbf{A})\mathbf{u} = \mathbf{0}, \quad \mathbf{u} \neq \mathbf{0} \tag{7.150} \]


Now $(\lambda\mathbf{I} - \mathbf{A})\mathbf{u} = \mathbf{0}$ has a non-zero solution $\mathbf{u}$ if and only if $(\lambda\mathbf{I} - \mathbf{A})$ has a non-trivial nullspace, which is only the case if $(\lambda\mathbf{I} - \mathbf{A})$ is singular, i.e.,


\[ \det(\lambda\mathbf{I} - \mathbf{A}) = 0 \tag{7.151} \]


This is called the characteristic equation of $\mathbf{A}$. (See Exercise 7.2.) The $n$ solutions of this equation are the $n$ (possibly complex-valued) eigenvalues $\lambda_i$, and $\mathbf{u}_i$ are the corresponding eigenvectors. It is standard to sort the eigenvectors in order of their eigenvalues, with the largest magnitude ones first.
The following are properties of eigenvalues and eigenvectors.


• The trace of a matrix is equal to the sum of its eigenvalues,
\[ \mathrm{tr}(\mathbf{A}) = \sum_{i=1}^{n} \lambda_i \tag{7.152} \]


• The determinant of $\mathbf{A}$ is equal to the product of its eigenvalues,
\[ \det(\mathbf{A}) = \prod_{i=1}^{n} \lambda_i \tag{7.153} \]


• If $\mathbf{A}$ is diagonalizable, the rank of $\mathbf{A}$ is equal to the number of non-zero eigenvalues of $\mathbf{A}$.


• If $\mathbf{A}$ is non-singular then $1/\lambda_i$ is an eigenvalue of $\mathbf{A}^{-1}$ with associated eigenvector $\mathbf{u}_i$, i.e., $\mathbf{A}^{-1}\mathbf{u}_i = (1/\lambda_i)\mathbf{u}_i$.


• The eigenvalues of a diagonal or triangular matrix are just the diagonal entries.


7.4.2 Diagonalization 


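We can write all the eigenvector equations simultaneously as


\[ \mathbf{A}\mathbf{U} = \mathbf{U}\boldsymbol{\Lambda} \tag{7.154} \]

where the columns of $\mathbf{U} \in \mathbb{R}^{n \times n}$ are the eigenvectors of $\mathbf{A}$ and $\boldsymbol{\Lambda}$ is a diagonal matrix whose entries are the eigenvalues of $\mathbf{A}$, i.e.,


\[ \mathbf{U} \in \mathbb{R}^{n \times n} = \begin{pmatrix} \mathbf{u}_1 & \mathbf{u}_2 & \cdots & \mathbf{u}_n \end{pmatrix}, \quad \boldsymbol{\Lambda} = \mathrm{diag}(\lambda_1, \ldots, \lambda_n) \tag{7.155} \]


If the eigenvectors of $\mathbf{A}$ are linearly independent, then the matrix $\mathbf{U}$ will be invertible, so
\[ \mathbf{A} = \mathbf{U}\boldsymbol{\Lambda}\mathbf{U}^{-1} \tag{7.156} \]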


A matrix that can be written in this form is called diagonalizable. 


7.4.3 Eigenvalues and eigenvectors of symmetric matrices 


When $A$ is real and symmetric, it can be shown that all the eigenvalues are real, and the eigenvectors
can be chosen to be orthonormal, i.e., $u_i^\top u_j = 0$ if $i \neq j$, and $u_i^\top u_i = 1$, where $u_i$ are the eigenvectors. In matrix
form, this becomes $U^\top U = U U^\top = I$; hence we see that $U$ is an orthogonal matrix.

We can therefore represent A as 


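$$A = U \Lambda U^\top = \begin{pmatrix} | & | & & | \\ u_1 & u_2 & \cdots & u_n \\ | & | & & | \end{pmatrix} \begin{pmatrix} \lambda_1 & & \\ & \ddots & \\ & & \lambda_n \end{pmatrix} \begin{pmatrix} - & u_1^\top & - \\ & \vdots & \\ - & u_n^\top & - \end{pmatrix} \tag{7.157}$$

$$= \lambda_1 u_1 u_1^\top + \cdots + \lambda_n u_n u_n^\top = \sum_{i=1}^n \lambda_i u_i u_i^\top \tag{7.158}$$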


Thus multiplying by any symmetric matrix $A$ can be interpreted as multiplying by a rotation matrix
$U^\top$, a scaling matrix $\Lambda$, followed by an inverse rotation $U$.

Once we have diagonalized a matrix, it is easy to invert. Since $A = U \Lambda U^\top$, where $U^\top = U^{-1}$, we
have

$$A^{-1} = U \Lambda^{-1} U^\top = \sum_{i=1}^n \frac{1}{\lambda_i} u_i u_i^\top \tag{7.159}$$

This corresponds to rotating, unscaling, and then rotating back.


7.4.3.1 Checking for positive definiteness 


We can also use the diagonalization property to show that a symmetric matrix is positive definite iff 
all its eigenvalues are positive. To see this, note that 


$$x^\top A x = x^\top U \Lambda U^\top x = y^\top \Lambda y = \sum_{i=1}^n \lambda_i y_i^2 \tag{7.160}$$

where $y = U^\top x$. Because $y_i^2$ is always nonnegative, the sign of this expression depends entirely on
the $\lambda_i$'s. If all $\lambda_i > 0$, then the matrix is positive definite; if all $\lambda_i \geq 0$, it is positive semidefinite.
Likewise, if all $\lambda_i < 0$ or $\lambda_i \leq 0$, then $A$ is negative definite or negative semidefinite respectively.
Finally, if $A$ has both positive and negative eigenvalues, it is indefinite.
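The eigenvalue test above translates directly into code. Below is a small sketch; the function name `definiteness` and the tolerance `tol` are assumptions introduced for illustration (a tolerance is needed because computed eigenvalues are only accurate up to rounding error).

```python
import numpy as np

def definiteness(A, tol=1e-10):
    """Classify a real symmetric matrix by the signs of its eigenvalues.

    `tol` is a hypothetical numerical threshold for treating an
    eigenvalue as zero.
    """
    evals = np.linalg.eigvalsh(A)  # eigenvalues of a symmetric matrix
    if np.all(evals > tol):
        return "positive definite"
    if np.all(evals >= -tol):
        return "positive semidefinite"
    if np.all(evals < -tol):
        return "negative definite"
    if np.all(evals <= tol):
        return "negative semidefinite"
    return "indefinite"

print(definiteness(np.array([[2.0, 0.0], [0.0, 1.0]])))
print(definiteness(np.array([[1.0, 0.0], [0.0, -1.0]])))
```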
[Figure: an ellipse in the $(x_1, x_2)$ plane, with axes along $u_1$ and $u_2$.]

Figure 7.6: Visualization of a level set of the quadratic form $(x - \mu)^\top A (x - \mu)$ in 2d. The major and minor
axes of the ellipse are defined by the first two eigenvectors of $A$, namely $u_1$ and $u_2$. Adapted from Figure 2.7
of [Bis06]. Generated by gaussEvec.ipynb.


7.4.4 Geometry of quadratic forms 


A quadratic form is a function that can be written as

$$f(x) = x^\top A x \tag{7.161}$$

where $x \in \mathbb{R}^n$ and $A$ is a positive definite, symmetric n-by-n matrix. Let $A = U \Lambda U^\top$ be a
diagonalization of $A$ (see Section 7.4.3). Hence we can write

$$f(x) = x^\top A x = x^\top U \Lambda U^\top x = y^\top \Lambda y = \sum_{i=1}^n \lambda_i y_i^2 \tag{7.162}$$

where $y_i = x^\top u_i$ and $\lambda_i > 0$ (since $A$ is positive definite). The level sets of $f(x)$ define hyper-ellipsoids.
For example, in 2d, we have

$$\lambda_1 y_1^2 + \lambda_2 y_2^2 = r \tag{7.163}$$

which is the equation of a 2d ellipse. This is illustrated in Figure 7.6. The eigenvectors determine
the orientation of the ellipse, and the eigenvalues determine how elongated it is. In particular, taking $r = 1$, the
major and minor semi-axes of the ellipse satisfy $a^{-2} = \lambda_1$ and $b^{-2} = \lambda_2$. In the case of a Gaussian
distribution, we have $A = \Sigma^{-1}$, so small values of $\lambda_i$ correspond to directions where the posterior
has low precision and hence high variance.


7.4.5 Standardizing and whitening data 


Figure 7.7: (a) Height/weight data. (b) Standardized. (c) PCA whitening. (d) ZCA whitening. Numbers refer
to the first 4 datapoints, but there are 73 datapoints in total. Generated by height_weight_whiten_plot.ipynb.

Suppose we have a dataset $X \in \mathbb{R}^{N \times D}$. It is common to preprocess the data so that each column has
zero mean and unit variance. This is called standardizing the data, as we discuss in Section 10.2.8.
Although standardizing forces the variance to be 1, it does not remove correlation between the
columns. To do that, we must whiten the data. To define this, let the (unnormalized) empirical covariance matrix
be $\Sigma = X^\top X$, and let $\Sigma = E D E^\top$ be its diagonalization. Equivalently, let $[U, S, V]$ be the SVD of
$X$ (so $E = V$ and $D = S^2$, as we discuss in Section 20.1.3.3). Now define

$$W_{pca} = D^{-1/2} E^\top \tag{7.164}$$


This is called the PCA whitening matrix. (We discuss PCA in Section 20.1.) Let $y = W_{pca} x$ be a
transformed vector. We can check that its covariance is white as follows:

$$\mathrm{Cov}[y] = W \mathbb{E}[x x^\top] W^\top = W \Sigma W^\top = (D^{-1/2} E^\top)(E D E^\top)(E D^{-1/2}) = I \tag{7.165}$$

The whitening matrix is not unique, since any rotation of it, $W = R W_{pca}$, will still maintain the
whitening property, i.e., $W^\top W = \Sigma^{-1}$. For example, if we take $R = E$, we get

$$W_{zca} = E D^{-1/2} E^\top = \Sigma^{-1/2} = V S^{-1} V^\top \tag{7.166}$$

This is called Mahalanobis whitening or ZCA. (ZCA stands for “zero-phase component analysis”,
and was introduced in [BS97].) The advantage of ZCA whitening over PCA whitening is that the 
resulting transformed data is as close as possible to the original data (in the least squares sense) 
[Amol17]. This is illustrated in Figure 7.7. When applied to images, the ZCA transformed data 
vectors still look like images. This is useful when the method is used inside a deep learning system 
[KH09]. 
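As an illustration of the two whitening transforms, here is a minimal NumPy sketch. The data matrix is synthetic (an assumption, not the height/weight data of Figure 7.7), and the covariance is normalized by $1/N$ in the code so that the whitened data has identity sample covariance.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical correlated 2d data with N = 500 rows.
X = rng.normal(size=(500, 2)) @ np.array([[2.0, 0.0], [1.5, 0.5]])
X = X - X.mean(axis=0)            # center each column

N = X.shape[0]
Sigma = X.T @ X / N               # empirical covariance (1/N normalization)
D, E = np.linalg.eigh(Sigma)      # Sigma = E diag(D) E^T

W_pca = np.diag(D ** -0.5) @ E.T  # D^{-1/2} E^T
W_zca = E @ W_pca                 # E D^{-1/2} E^T, i.e. rotate back

# Each row x of X is mapped to y = W x, so Y = X W^T.
Y_pca = X @ W_pca.T
Y_zca = X @ W_zca.T

cov_pca = Y_pca.T @ Y_pca / N
cov_zca = Y_zca.T @ Y_zca / N
print(np.round(cov_pca, 6))
print(np.round(cov_zca, 6))
```

Both covariances should be the identity up to rounding; the ZCA matrix is symmetric, which is one way to see that it does not introduce an extra rotation.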


7.4.6 Power method 


We now describe a simple iterative method for computing the eigenvector corresponding to the largest 
eigenvalue of a real, symmetric matrix; this is called the power method. This can be useful when 
the matrix is very large but sparse. For example, it is used by Google’s PageRank to compute 
the stationary distribution of the transition matrix of the world wide web (a matrix of size about 3 
billion by 3 billion!). In Section 7.4.7, we will see how to use this method to compute subsequent 
eigenvectors and values. 

Let $A$ be a matrix with orthonormal eigenvectors $u_i$ and eigenvalues $|\lambda_1| > |\lambda_2| \geq \cdots \geq |\lambda_m| \geq 0$,
so $A = U \Lambda U^\top$. Let $v_{(0)}$ be an arbitrary vector in the range of $A$, so $A x = v_{(0)}$ for some $x$. Hence
we can write $v_{(0)}$ as

$$v_0 = U(\Lambda U^\top x) = a_1 u_1 + \cdots + a_m u_m \tag{7.167}$$

for some constants $a_i$. We can now repeatedly multiply $v$ by $A$ and renormalize:

$$v_t \propto A v_{t-1} \tag{7.168}$$

(We normalize at each iteration for numerical stability.)
Since $v_t$ is a multiple of $A^t v_0$, we have

$$v_t \propto a_1 \lambda_1^t u_1 + a_2 \lambda_2^t u_2 + \cdots + a_m \lambda_m^t u_m \tag{7.169}$$

$$= \lambda_1^t \left( a_1 u_1 + a_2 (\lambda_2/\lambda_1)^t u_2 + \cdots + a_m (\lambda_m/\lambda_1)^t u_m \right) \tag{7.170}$$

$$\to \lambda_1^t a_1 u_1 \tag{7.171}$$

since $|\lambda_k| / |\lambda_1| < 1$ for $k > 1$ (assuming the eigenvalues are sorted in descending order). So we see that
this converges to $u_1$, although not very quickly (the error is reduced by approximately $|\lambda_2/\lambda_1|$ at
each iteration). The only requirement is that the initial guess satisfy $v_0^\top u_1 \neq 0$, which will be true
for a random $v_0$ with high probability.

We now discuss how to compute the corresponding eigenvalue, $\lambda_1$. Define the Rayleigh quotient
to be

$$R(A, x) \triangleq \frac{x^\top A x}{x^\top x} \tag{7.172}$$

Hence

$$R(A, u_i) = \frac{u_i^\top A u_i}{u_i^\top u_i} = \frac{\lambda_i u_i^\top u_i}{u_i^\top u_i} = \lambda_i \tag{7.173}$$

Thus we can easily compute $\lambda_1$ from $u_1$ and $A$. See power_method_demo.ipynb for a demo.
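The iteration in Equation (7.168) together with the Rayleigh quotient of Equation (7.172) can be sketched as follows. This is an illustrative implementation, not the code in the referenced notebook; the function name `power_method`, the fixed iteration count and the example matrix are assumptions.

```python
import numpy as np

def power_method(A, num_iters=200, seed=0):
    """Estimate the dominant eigenpair of a real symmetric matrix A."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=A.shape[0])   # random start, so v0^T u1 != 0 w.h.p.
    v /= np.linalg.norm(v)
    for _ in range(num_iters):
        w = A @ v                     # multiply by A ...
        v = w / np.linalg.norm(w)     # ... and renormalize (Equation 7.168)
    lam = (v @ A @ v) / (v @ v)       # Rayleigh quotient (Equation 7.172)
    return lam, v

# Hypothetical symmetric test matrix.
A = np.array([[2.0, 1.0],
              [1.0, 3.0]])
lam1, u1 = power_method(A)
print(lam1, u1)
```

The result can be compared against `np.linalg.eigh(A)`; the eigenvector is only determined up to sign.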
7.4.7 Deflation


Suppose we have computed the first eigenvector and value $u_1$, $\lambda_1$ by the power method. We now
describe how to compute subsequent eigenvectors and values. Since the eigenvectors are orthonormal,
and the eigenvalues are real, we can project out the $u_1$ component from the matrix as follows:

$$A^{(2)} = (I - u_1 u_1^\top) A^{(1)} = A^{(1)} - u_1 u_1^\top A^{(1)} = A^{(1)} - \lambda_1 u_1 u_1^\top \tag{7.174}$$

This is called matrix deflation. We can then apply the power method to $A^{(2)}$, which will find the
largest eigenvector/value in the subspace orthogonal to $u_1$.

In Section 20.1.2, we show that the optimal estimate W for the PCA model (described in 
Section 20.1) is given by the first K eigenvectors of the empirical covariance matrix. Hence deflation 
can be used to implement PCA. It can also be modified to implement sparse PCA [Mac09]. 
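A minimal sketch of power iteration combined with deflation is shown below. The helper names `top_eig` and `eig_by_deflation`, the iteration count, and the example matrix are assumptions made for illustration; the sketch assumes a symmetric matrix whose eigenvalues have distinct magnitudes.

```python
import numpy as np

def top_eig(A, num_iters=500, seed=0):
    """Power iteration for the dominant eigenpair of a symmetric matrix."""
    v = np.random.default_rng(seed).normal(size=A.shape[0])
    for _ in range(num_iters):
        v = A @ v
        v /= np.linalg.norm(v)
    return v @ A @ v, v               # Rayleigh quotient (v has unit norm)

def eig_by_deflation(A, K):
    """Return the K leading eigenvalues/eigenvectors via repeated deflation."""
    A_k = A.copy()
    evals, evecs = [], []
    for _ in range(K):
        lam, u = top_eig(A_k)
        evals.append(lam)
        evecs.append(u)
        A_k = A_k - lam * np.outer(u, u)   # deflate: A - lambda u u^T
    return np.array(evals), np.column_stack(evecs)

# Hypothetical symmetric positive definite matrix.
A = np.array([[4.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]])
evals, evecs = eig_by_deflation(A, K=3)
print(evals)
```

In practice, errors in the earlier eigenpairs propagate to the later ones, so library routines such as `np.linalg.eigh` are preferable when all eigenpairs of a moderately sized matrix are needed.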

7.4.8 Eigenvectors optimize quadratic forms 


We can use matrix calculus to solve an optimization problem in a way that leads directly to 
eigenvalue/eigenvector analysis. Consider the following, equality constrained optimization problem: 


$$\max_{x \in \mathbb{R}^n} x^\top A x \quad \text{subject to} \quad \|x\|_2^2 = 1 \tag{7.175}$$

for a symmetric matrix $A \in \mathbb{S}^n$. A standard way of solving optimization problems with equality
constraints is by forming the Lagrangian, an objective function that includes the equality constraints
(see Section 8.5.1). The Lagrangian in this case can be given by

$$L(x, \lambda) = x^\top A x + \lambda (1 - x^\top x) \tag{7.176}$$

where $\lambda$ is called the Lagrange multiplier associated with the equality constraint. It can be established
that for $x^*$ to be an optimal point to the problem, the gradient of the Lagrangian has to be zero at $x^*$
(this is not the only condition, but it is required). That is,

$$\nabla_x L(x, \lambda) = 2 A^\top x - 2 \lambda x = 0. \tag{7.177}$$

Notice that this is just the linear equation $A x = \lambda x$. This shows that the only points which can
possibly maximize (or minimize) $x^\top A x$ assuming $x^\top x = 1$ are the eigenvectors of $A$.


7.5 Singular value decomposition (SVD) 


We now discuss the SVD, which generalizes EVD to rectangular matrices. 


7.5.1 Basics 


Any (real) $m \times n$ matrix $A$ can be decomposed as

$$A = U S V^\top = \sigma_1 \begin{pmatrix} | \\ u_1 \\ | \end{pmatrix} \begin{pmatrix} - & v_1^\top & - \end{pmatrix} + \cdots + \sigma_r \begin{pmatrix} | \\ u_r \\ | \end{pmatrix} \begin{pmatrix} - & v_r^\top & - \end{pmatrix} \tag{7.178}$$

where $U$ is an $m \times m$ matrix whose columns are orthonormal (so $U^\top U = I_m$), $V$ is an $n \times n$ matrix whose
rows and columns are orthonormal (so $V^\top V = V V^\top = I_n$), and $S$ is an $m \times n$ matrix containing the
$r = \min(m, n)$ singular values $\sigma_i \geq 0$ on the main diagonal, with 0s filling the rest of the matrix.
The columns of $U$ are the left singular vectors, and the columns of $V$ are the right singular vectors.
This is called the singular value decomposition or SVD of the matrix. See Figure 7.8 for an
example.

Figure 7.8: SVD decomposition of a matrix, $A = U S V^\top$. The shaded parts of each matrix are not computed
in the economy-sized version. (a) Tall skinny matrix. (b) Short wide matrix.

As is apparent from Figure 7.8a, if $m > n$, there are at most $n$ singular values, so the last $m - n$
columns of $U$ are irrelevant (since they will be multiplied by 0). The economy sized SVD, also
called a thin SVD, avoids computing these unnecessary elements. In other words, if we write the $U$
matrix as $U = [U_1, U_2]$, we only compute $U_1$. Figure 7.8b shows the opposite case, where $m < n$,
where we represent $V = [V_1; V_2]$, and only compute $V_1$.

The cost of computing the SVD is $O(\min(mn^2, m^2 n))$. Details on how it works can be found in
standard linear algebra textbooks.
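The difference between the full and economy-sized SVD can be seen in NumPy through the `full_matrices` argument of `np.linalg.svd`. The sketch below uses a random tall matrix (an assumption for illustration); note that NumPy returns $V^\top$ rather than $V$, and the singular values as a 1d array.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 3))   # tall, skinny: m = 5, n = 3

# Full SVD: U is m x m, Vt is n x n, s holds the min(m, n) singular values.
U, s, Vt = np.linalg.svd(A, full_matrices=True)

# Economy-sized (thin) SVD: only the first n columns of U are computed.
U1, s1, Vt1 = np.linalg.svd(A, full_matrices=False)

# Rebuild the m x n matrix S with the singular values on the main diagonal.
S = np.zeros((5, 3))
S[:3, :3] = np.diag(s)

full_ok = np.allclose(U @ S @ Vt, A)
thin_ok = np.allclose(U1 @ np.diag(s1) @ Vt1, A)
print(U.shape, U1.shape, full_ok, thin_ok)
```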


7.5.2 Connection between SVD and EVD 


If A is real, symmetric and positive definite, then the singular values are equal to the eigenvalues, 
and the left and right singular vectors are equal to the eigenvectors (up to a sign change): 


$$A = U S V^\top = U S U^\top = U S U^{-1} \tag{7.179}$$

Note, however, that NumPy always returns the singular values in decreasing order, whereas the
eigenvalues need not necessarily be sorted.
In general, for an arbitrary real matrix $A$, if $A = U S V^\top$, we have

$$A^\top A = V S^\top U^\top U S V^\top = V (S^\top S) V^\top \tag{7.180}$$

Hence

$$(A^\top A) V = V D_n \tag{7.181}$$

so the eigenvectors of $A^\top A$ are equal to $V$, the right singular vectors of $A$, and the eigenvalues of
$A^\top A$ are equal to $D_n = S^\top S$, which is an $n \times n$ diagonal matrix containing the squared singular
values. Similarly

$$A A^\top = U S V^\top V S^\top U^\top = U (S S^\top) U^\top \tag{7.182}$$

$$(A A^\top) U = U D_m \tag{7.183}$$

so the eigenvectors of $A A^\top$ are equal to $U$, the left singular vectors of $A$, and the eigenvalues of
$A A^\top$ are equal to $D_m = S S^\top$, which is an $m \times m$ diagonal matrix containing the squared singular
values. In summary,

$$U = \mathrm{evec}(A A^\top), \quad V = \mathrm{evec}(A^\top A), \quad D_m = \mathrm{eval}(A A^\top), \quad D_n = \mathrm{eval}(A^\top A) \tag{7.184}$$

If we just use the computed (non-zero) parts in the economy-sized SVD, then we can define

$$D = S^2 = S^\top S = S S^\top \tag{7.185}$$


Note also that an EVD does not always exist, even for square A, whereas an SVD always exists. 


7.5.3 Pseudo inverse 


The Moore-Penrose pseudo-inverse of $A$, denoted $A^\dagger$, is defined as the unique
matrix that satisfies the following 4 properties:

$$A A^\dagger A = A, \quad A^\dagger A A^\dagger = A^\dagger, \quad (A A^\dagger)^\top = A A^\dagger, \quad (A^\dagger A)^\top = A^\dagger A \tag{7.186}$$

If $A$ is square and non-singular, then $A^\dagger = A^{-1}$.
If $m > n$ (tall, skinny) and the columns of $A$ are linearly independent (so $A$ is full rank), then

$$A^\dagger = (A^\top A)^{-1} A^\top \tag{7.187}$$

which is the same expression as arises in the normal equations (see Section 11.2.2.1). In this case,
$A^\dagger$ is a left inverse of $A$ because

$$A^\dagger A = (A^\top A)^{-1} A^\top A = I \tag{7.188}$$

but is not a right inverse because

$$A A^\dagger = A (A^\top A)^{-1} A^\top \tag{7.189}$$

only has rank $n$, and so cannot be the $m \times m$ identity matrix.
If $m < n$ (short, fat) and the rows of $A$ are linearly independent (so $A^\top$ is full rank), then the
pseudo inverse is

$$A^\dagger = A^\top (A A^\top)^{-1} \tag{7.190}$$

In this case, $A^\dagger$ is a right inverse of $A$.
We can compute the pseudo inverse using the SVD decomposition $A = U S V^\top$. In particular, one
can show that

$$A^\dagger = V \left[ \mathrm{diag}(1/\sigma_1, \cdots, 1/\sigma_r, 0, \cdots, 0) \right] U^\top = V S^{-1} U^\top \tag{7.191}$$

where $r$ is the rank of the matrix, and where we define $S^{-1} = \mathrm{diag}(\sigma_1^{-1}, \ldots, \sigma_r^{-1}, 0, \ldots, 0)$. Indeed if
the matrices were square and full rank we would have

$$(U S V^\top)^{-1} = V S^{-1} U^\top \tag{7.192}$$
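Equation (7.191) can be sketched in a few lines of NumPy. The function name `pinv_svd` and the relative threshold `rtol` (used to decide which singular values count as zero) are assumptions for illustration; NumPy's own `np.linalg.pinv` uses the same idea.

```python
import numpy as np

def pinv_svd(A, rtol=1e-12):
    """Pseudo inverse via the thin SVD, inverting only non-zero singular values."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    s_inv = np.zeros_like(s)
    keep = s > rtol * s.max()         # hypothetical cutoff for "non-zero"
    s_inv[keep] = 1.0 / s[keep]
    # V diag(s_inv) U^T; scaling the rows of U^T avoids forming diag explicitly.
    return Vt.T @ (s_inv[:, None] * U.T)

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 3))           # tall matrix with independent columns
A_pinv = pinv_svd(A)

print(np.allclose(A_pinv, np.linalg.pinv(A)))
print(np.allclose(A_pinv @ A, np.eye(3)))   # left inverse, Equation (7.188)
```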
7.5.4 SVD and the range and null space of a matrix *


In this section, we show that the left and right singular vectors form an orthonormal basis for the
range and null space.
From Equation (7.178) we have

$$A x = \sum_{j : \sigma_j > 0} \sigma_j (v_j^\top x) u_j = \sum_{j=1}^r \sigma_j (v_j^\top x) u_j \tag{7.193}$$

where $r$ is the rank of $A$. Thus any $A x$ can be written as a linear combination of the left singular
vectors $u_1, \ldots, u_r$, so the range of $A$ is given by

$$\mathrm{range}(A) = \mathrm{span}(\{u_j : \sigma_j > 0\}) \tag{7.194}$$

with dimension $r$.
To find a basis for the null space, let us now define a second vector $y \in \mathbb{R}^n$ that is a linear
combination solely of the right singular vectors for the zero singular values,

$$y = \sum_{j : \sigma_j = 0} c_j v_j = \sum_{j=r+1}^n c_j v_j \tag{7.195}$$

Since the $v_j$'s are orthonormal, we have

$$A y = U \begin{pmatrix} \sigma_1 v_1^\top y \\ \vdots \\ \sigma_r v_r^\top y \\ \sigma_{r+1} v_{r+1}^\top y \\ \vdots \\ \sigma_n v_n^\top y \end{pmatrix} = U \begin{pmatrix} \sigma_1 \cdot 0 \\ \vdots \\ \sigma_r \cdot 0 \\ 0 \cdot v_{r+1}^\top y \\ \vdots \\ 0 \cdot v_n^\top y \end{pmatrix} = U 0 = 0 \tag{7.196}$$

Hence the right singular vectors form an orthonormal basis for the null space:

$$\mathrm{nullspace}(A) = \mathrm{span}(\{v_j : \sigma_j = 0\}) \tag{7.197}$$

with dimension $n - r$. We see that

$$\dim(\mathrm{range}(A)) + \dim(\mathrm{nullspace}(A)) = r + (n - r) = n \tag{7.198}$$

In words, this is often written as

$$\text{rank} + \text{nullity} = n \tag{7.199}$$

This is called the rank-nullity theorem. It follows from this that the rank of a matrix is the
number of nonzero singular values.
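The construction above can be sketched numerically: the singular vectors returned by the full SVD are split according to whether the associated singular value is (numerically) zero. The example matrix and the tolerance rule below are assumptions for illustration.

```python
import numpy as np

# Hypothetical rank-2 matrix: the third column is the sum of the first two.
A = np.array([[1.0, 0.0, 1.0],
              [0.0, 1.0, 1.0],
              [1.0, 1.0, 2.0],
              [2.0, 0.0, 2.0]])

U, s, Vt = np.linalg.svd(A)   # full SVD: U is 4 x 4, Vt is 3 x 3

# Treat singular values below a small threshold as zero (numerical rank).
tol = s.max() * max(A.shape) * np.finfo(float).eps
r = int(np.sum(s > tol))

range_basis = U[:, :r]        # left singular vectors with sigma_j > 0
null_basis = Vt[r:].T         # right singular vectors with sigma_j = 0

print(r, range_basis.shape, null_basis.shape)
print(np.allclose(A @ null_basis, 0.0))
```

Here `r + null_basis.shape[1]` equals the number of columns of `A`, in line with the rank-nullity theorem.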
Figure 7.9: Low rank approximations to an image. Top left: The original image is of size 200 x 320, so has
rank 200. Subsequent images have ranks 2, 5, and 20. Generated by svd_image_demo.ipynb.

Figure 7.10: First 100 log singular values for the clown image (red line), and for a data matrix obtained by
randomly shuffling the pixels (blue line). Generated by svd_image_demo.ipynb. Adapted from Figure 14.24
of [HTF09].
7.5.5 Truncated SVD


Let $A = U S V^\top$ be the SVD of $A$, and let $\hat{A}_K = U_K S_K V_K^\top$, where we use the first $K$ columns of $U$
and $V$. This can be shown to be the optimal rank $K$ approximation, in the sense that it minimizes
$\|A - \hat{A}_K\|_F^2$.

If $K = r = \mathrm{rank}(A)$, there is no error introduced by this decomposition. But if $K < r$, we incur
some error. This is called a truncated SVD. If the singular values die off quickly, as is typical in
natural data (see e.g., Figure 7.10), the error will be small. The total number of parameters needed
to represent an $N \times D$ matrix using a rank $K$ approximation is

$$NK + KD + K = K(N + D + 1) \tag{7.200}$$

As an example, consider the 200 x 320 pixel image in Figure 7.9 (top left). This has 64,000 numbers
in it. We see that a rank 20 approximation, with only (200 + 320 + 1) x 20 = 10,420 numbers, is a
very good approximation.

One can show that the squared error in this rank-K approximation is given by

$$\|A - \hat{A}_K\|_F^2 = \sum_{k=K+1}^r \sigma_k^2 \tag{7.201}$$

where $\sigma_k$ is the $k$'th singular value of $A$.
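The rank-K approximation and its error can be checked with a short NumPy sketch. The matrix below is random (an assumption; it is not the image from Figure 7.9), and the check uses the standard Eckart-Young result that the squared Frobenius error equals the sum of the squared discarded singular values.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(20, 30))      # hypothetical data matrix, N = 20, D = 30
U, s, Vt = np.linalg.svd(A, full_matrices=False)

K = 5
# Keep the first K singular values and the matching columns of U and V.
A_K = U[:, :K] @ np.diag(s[:K]) @ Vt[:K]

err_sq = np.linalg.norm(A - A_K, "fro") ** 2
tail_sq = np.sum(s[K:] ** 2)       # sum of squared discarded singular values

# Number of stored values: N*K + K*D + K.
n_params = K * (A.shape[0] + A.shape[1] + 1)
print(err_sq, tail_sq, n_params)
```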


7.6 Other matrix decompositions * 


In this section, we briefly review some other useful matrix decompositions. 


7.6.1 LU factorization 


We can factorize any square matrix A into a product of a lower triangular matrix L and an upper 
triangular matrix U. For example, 


$$\begin{pmatrix} a_{11} & a_{12} & a_{13} \\ a_{21} & a_{22} & a_{23} \\ a_{31} & a_{32} & a_{33} \end{pmatrix} = \begin{pmatrix} l_{11} & 0 & 0 \\ l_{21} & l_{22} & 0 \\ l_{31} & l_{32} & l_{33} \end{pmatrix} \begin{pmatrix} u_{11} & u_{12} & u_{13} \\ 0 & u_{22} & u_{23} \\ 0 & 0 & u_{33} \end{pmatrix}. \tag{7.202}$$

In general we may need to permute the entries in the matrix before creating this decomposition.
To see this, suppose $a_{11} = 0$. Since $a_{11} = l_{11} u_{11}$, this means either $l_{11}$ or $u_{11}$ or both must be zero,
but that would imply $L$ or $U$ are singular. To avoid this, the first step of the algorithm can simply
reorder the rows so that the first element is nonzero. This is repeated for subsequent steps. We can
denote this process by

$$P A = L U \tag{7.203}$$

where $P$ is a permutation matrix, i.e., a square binary matrix where $P_{ij} = 1$ if row $j$ gets permuted
to row $i$. This is called partial pivoting.
Figure 7.11: Illustration of QR decomposition, $A = QR$, where $Q^\top Q = I$ and $R$ is upper triangular. (a)
Tall, skinny matrix. The shaded parts are not computed in the economy-sized version, since they are not
needed. (b) Short, wide matrix.


7.6.2 QR decomposition 


Suppose we have $A \in \mathbb{R}^{m \times n}$ representing a set of linearly independent basis vectors (so $m \geq n$),
and we want to find a series of orthonormal vectors $q_1, q_2, \ldots$ that span the successive subspaces of
$\mathrm{span}(a_1)$, $\mathrm{span}(a_1, a_2)$, etc. In other words, we want to find vectors $q_j$ and coefficients $r_{ij}$ such that

$$\begin{pmatrix} | & | & & | \\ a_1 & a_2 & \cdots & a_n \\ | & | & & | \end{pmatrix} = \begin{pmatrix} | & | & & | \\ q_1 & q_2 & \cdots & q_n \\ | & | & & | \end{pmatrix} \begin{pmatrix} r_{11} & r_{12} & \cdots & r_{1n} \\ & r_{22} & \cdots & r_{2n} \\ & & \ddots & \vdots \\ & & & r_{nn} \end{pmatrix} \tag{7.204}$$

We can write this

$$a_1 = r_{11} q_1 \tag{7.205}$$

$$a_2 = r_{12} q_1 + r_{22} q_2 \tag{7.206}$$

$$\vdots$$

$$a_n = r_{1n} q_1 + \cdots + r_{nn} q_n \tag{7.207}$$

so we see $q_1$ spans the space of $a_1$, and $q_1$ and $q_2$ span the space of $\{a_1, a_2\}$, etc.
In matrix notation, we have

$$A = \hat{Q} \hat{R} \tag{7.208}$$

where $\hat{Q}$ is $m \times n$ with orthonormal columns and $\hat{R}$ is $n \times n$ and upper triangular. This is called a
reduced QR or economy sized QR factorization of $A$; see Figure 7.11.

A full QR factorization appends an additional $m - n$ orthonormal columns to $\hat{Q}$ so it becomes a
square, orthogonal matrix $Q$, which satisfies $Q Q^\top = Q^\top Q = I$. Also, we append rows made of zero
to $\hat{R}$ so it becomes an $m \times n$ matrix that is still upper triangular, called $R$: see Figure 7.11. The
zero entries in $R$ “kill off” the new columns in $Q$, so the result is the same as $\hat{Q} \hat{R}$.

QR decomposition is commonly used to solve systems of linear equations, as we discuss in 
Section 11.2.2.3. 
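In NumPy, the reduced and full factorizations correspond to the `mode="reduced"` and `mode="complete"` options of `np.linalg.qr`. The following sketch, on a random tall matrix chosen for illustration, checks the shapes and the properties described above.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 3))        # m = 5, n = 3

# Reduced (economy-sized) QR: Q_hat is m x n, R_hat is n x n.
Q_hat, R_hat = np.linalg.qr(A, mode="reduced")

# Full QR: Q is m x m orthogonal, R is m x n with zero rows appended.
Q, R = np.linalg.qr(A, mode="complete")

print(Q_hat.shape, R_hat.shape, Q.shape, R.shape)
print(np.allclose(Q_hat @ R_hat, A), np.allclose(Q @ R, A))
```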
7.6.3 Cholesky decomposition


Any symmetric positive definite matrix can be factorized as $A = R^\top R$, where $R$ is upper triangular
with real, positive diagonal elements. (This can also be written as $A = L L^\top$, where $L = R^\top$ is
lower triangular.) This is called a Cholesky factorization or matrix square root. In NumPy,
this is implemented by np.linalg.cholesky. The computational complexity of this operation is
$O(V^3)$, where $V$ is the number of variables, but can be less for sparse matrices. Below we give some
applications of this factorization.


7.6.3.1 Application: Sampling from an MVN 


The Cholesky decomposition of a covariance matrix can be used to sample from a multivariate
Gaussian. Let $y \sim \mathcal{N}(\mu, \Sigma)$ and $\Sigma = L L^\top$. We first sample $x \sim \mathcal{N}(0, I)$, which is easy because it
just requires sampling from $d$ separate 1d Gaussians. We then set $y = L x + \mu$. This is valid since

$$\mathrm{Cov}[y] = L \, \mathrm{Cov}[x] \, L^\top = L I L^\top = \Sigma \tag{7.209}$$

See cholesky_demo.ipynb for some code.
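The sampling recipe can be sketched as follows. This is not the code from the referenced notebook; the mean, covariance and sample size are hypothetical values chosen for illustration. `np.linalg.cholesky` returns the lower triangular factor $L$.

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([1.0, -2.0])                   # hypothetical mean
Sigma = np.array([[2.0, 0.8],
                  [0.8, 1.0]])               # hypothetical covariance (positive definite)

L = np.linalg.cholesky(Sigma)                # lower triangular, Sigma = L @ L.T

# Each row of X is a draw x ~ N(0, I) from d = 2 independent 1d Gaussians.
X = rng.standard_normal(size=(100_000, 2))

# y = L x + mu, applied row-wise.
Y = X @ L.T + mu

print(Y.mean(axis=0))
print(np.cov(Y, rowvar=False))
```

The sample mean and covariance should be close to `mu` and `Sigma`, with agreement improving as the number of samples grows.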


7.7 Solving systems of linear equations * 


An important application of linear algebra is the study of systems of linear equations. For example,
consider the following set of 3 equations:

$$3 x_1 + 2 x_2 - x_3 = 1 \tag{7.210}$$

$$2 x_1 - 2 x_2 + 4 x_3 = -2 \tag{7.211}$$

$$-x_1 + \tfrac{1}{2} x_2 - x_3 = 0 \tag{7.212}$$

We can represent this in matrix-vector form as follows:

$$A x = b \tag{7.213}$$

where

$$A = \begin{pmatrix} 3 & 2 & -1 \\ 2 & -2 & 4 \\ -1 & \tfrac{1}{2} & -1 \end{pmatrix}, \quad b = \begin{pmatrix} 1 \\ -2 \\ 0 \end{pmatrix} \tag{7.214}$$

The solution is $x = [1, -2, -2]$.

In general, if we have $m$ equations and $n$ unknowns, then $A$ will be an $m \times n$ matrix, and $b$ will be
an $m \times 1$ vector. If $m = n$ (and $A$ is full rank), there is a single unique solution. If $m < n$, the system
is underdetermined, so there is not a unique solution. If $m > n$, the system is overdetermined,
since there are more constraints than unknowns, and not all the lines intersect at the same point.
See Figure 7.12 for an illustration. We discuss how to compute solutions in each of these cases below.
Figure 7.12: Solution of a set of $m$ linear equations in $n = 2$ variables. (a) $m = 1 < n$ so the system is
underdetermined. We show the minimal norm solution as a blue circle. (The dotted red line is orthogonal
to the line, and its length is the distance to the origin.) (b) $m = n = 2$, so there is a unique solution. (c)
$m = 3 > n$, so there is no unique solution. We show the least squares solution.


7.7.1 Solving square systems 


In the case where m = n, we can solve for x by computing an LU decomposition, A = LU, and then 
proceeding as follows: 


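$$A x = b \tag{7.215}$$

$$L U x = b \tag{7.216}$$

$$U x = L^{-1} b \triangleq y \tag{7.217}$$

$$x = U^{-1} y \tag{7.218}$$

The crucial point is that $L$ and $U$ are both triangular matrices, so we can avoid taking matrix
inverses, and use a method known as backsubstitution instead.
In particular, we can solve $y = L^{-1} b$ without taking inverses as follows. First we write

$$\begin{pmatrix} L_{11} & & & \\ L_{21} & L_{22} & & \\ \vdots & & \ddots & \\ L_{n1} & L_{n2} & \cdots & L_{nn} \end{pmatrix} \begin{pmatrix} y_1 \\ \vdots \\ y_n \end{pmatrix} = \begin{pmatrix} b_1 \\ \vdots \\ b_n \end{pmatrix} \tag{7.219}$$

We start by solving $L_{11} y_1 = b_1$ to find $y_1$ and then substitute this in to solve

$$L_{21} y_1 + L_{22} y_2 = b_2 \tag{7.220}$$

for $y_2$. We repeat this recursively. (For a lower triangular matrix, this top-down variant is also called forward substitution.) This process is often denoted by the backslash operator,
$y = L \backslash b$. Once we have $y$, we can solve $x = U^{-1} y$ using backsubstitution in a similar manner.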
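The procedure above can be sketched in NumPy as follows. The helper names (`lu_partial_pivot`, `forward_substitution`, `back_substitution`) are assumptions introduced for illustration; the LU routine uses partial pivoting as in Equation (7.203), and the test system is the one in Equation (7.214).

```python
import numpy as np

def lu_partial_pivot(A):
    """Return P, L, U with P @ A = L @ U, using row swaps for stability."""
    n = A.shape[0]
    U = A.astype(float).copy()
    L = np.eye(n)
    P = np.eye(n)
    for k in range(n - 1):
        p = k + np.argmax(np.abs(U[k:, k]))   # row with the largest pivot
        if p != k:
            U[[k, p], k:] = U[[p, k], k:]
            L[[k, p], :k] = L[[p, k], :k]
            P[[k, p]] = P[[p, k]]
        for i in range(k + 1, n):
            L[i, k] = U[i, k] / U[k, k]
            U[i, k:] -= L[i, k] * U[k, k:]
    return P, L, U

def forward_substitution(L, b):
    """Solve L y = b for lower triangular L, from the first row down."""
    y = np.zeros(len(b))
    for i in range(len(b)):
        y[i] = (b[i] - L[i, :i] @ y[:i]) / L[i, i]
    return y

def back_substitution(U, y):
    """Solve U x = y for upper triangular U, from the last row up."""
    x = np.zeros(len(y))
    for i in reversed(range(len(y))):
        x[i] = (y[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i]
    return x

A = np.array([[3.0, 2.0, -1.0],
              [2.0, -2.0, 4.0],
              [-1.0, 0.5, -1.0]])
b = np.array([1.0, -2.0, 0.0])

P, L, U = lu_partial_pivot(A)
y = forward_substitution(L, P @ b)   # P A x = P b, so L (U x) = P b
x = back_substitution(U, y)
print(x)
```

In practice one would call a library routine such as `np.linalg.solve(A, b)`, which performs an LU factorization with pivoting internally.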
7.7.2 Solving underconstrained systems (least norm estimation) 


In this section, we consider the underconstrained setting, where $m < n$.³ We assume the rows are
linearly independent, so $A$ is full rank.


3. Our presentation is based in part on lecture notes by Stephen Boyd at
http://ee263.stanford.edu/lectures/min-norm.pdf.
When $m < n$, there are multiple possible solutions, which have the form

$$\{x : A x = b\} = \{x_p + z : z \in \mathrm{nullspace}(A)\} \tag{7.221}$$

where $x_p$ is any particular solution. It is standard to pick the particular solution with minimal $\ell_2$
norm, i.e.,

$$\hat{x} = \operatorname*{argmin}_x \|x\|_2^2 \quad \text{s.t.} \quad A x = b \tag{7.222}$$

We can compute the minimal norm solution using the right pseudo inverse:

$$x_{pinv} = A^\top (A A^\top)^{-1} b \tag{7.223}$$

(See Section 7.5.3 for more details.)
To see this, suppose $x$ is some other solution, so $A x = b$, and $A(x - x_{pinv}) = 0$. Thus

$$(x - x_{pinv})^\top x_{pinv} = (x - x_{pinv})^\top A^\top (A A^\top)^{-1} b = (A(x - x_{pinv}))^\top (A A^\top)^{-1} b = 0 \tag{7.224}$$

and hence $(x - x_{pinv}) \perp x_{pinv}$. By Pythagoras's theorem, the norm of $x$ is

$$\|x\|^2 = \|x_{pinv} + x - x_{pinv}\|^2 = \|x_{pinv}\|^2 + \|x - x_{pinv}\|^2 \geq \|x_{pinv}\|^2 \tag{7.225}$$

Thus any solution apart from $x_{pinv}$ has larger norm.
We can also solve the constrained optimization problem in Equation (7.222) by minimizing the
following unconstrained objective

$$L(x, \lambda) = x^\top x + \lambda^\top (A x - b) \tag{7.226}$$

From Section 8.5.1, the optimality conditions are

$$\nabla_x L = 2 x + A^\top \lambda = 0, \quad \nabla_\lambda L = A x - b = 0 \tag{7.227}$$

From the first condition we have $x = -A^\top \lambda / 2$. Substituting into the second we get

$$A x = -\frac{1}{2} A A^\top \lambda = b \tag{7.228}$$

which implies $\lambda = -2 (A A^\top)^{-1} b$. Hence $x = A^\top (A A^\top)^{-1} b$, which is the right pseudo inverse
solution.
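The minimal norm solution of Equation (7.223) can be checked numerically. In the sketch below, the system is random (an assumption for illustration), and a second solution is built by adding a null-space vector obtained from the SVD, to confirm that it satisfies the constraints but has larger norm.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 4))     # m = 2 < n = 4; rows independent with probability 1
b = rng.normal(size=2)

# Right pseudo inverse solution; solve() avoids forming (A A^T)^{-1} explicitly.
x_pinv = A.T @ np.linalg.solve(A @ A.T, b)

# Any other solution has the form x_pinv + z with z in the null space of A.
_, _, Vt = np.linalg.svd(A)     # full SVD: the last n - m rows of Vt span the null space
z = Vt[2:].T @ rng.normal(size=2)
x_other = x_pinv + z

print(np.allclose(A @ x_pinv, b), np.allclose(A @ x_other, b))
print(np.linalg.norm(x_pinv), np.linalg.norm(x_other))
```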


7.7.3 Solving overconstrained systems (least squares estimation) 


If $m > n$, we have an overdetermined system, which typically does not have an exact solution,
but we will try to find the solution that gets as close as possible to satisfying all of the constraints
specified by $A x = b$. We can do this by minimizing the following cost function, known as the least
squares objective:⁴

$$f(x) = \frac{1}{2} \|A x - b\|_2^2 \tag{7.233}$$


4. Note that some equation numbers have been skipped. This is intentional. The reason is that I have omitted
some erroneous material from an earlier version (described in https://github.com/probml/pml-book/issues/266),
but want to make sure the equation numbering is consistent across different versions of the book.
Using matrix calculus results from Section 7.8 we have that the gradient is given by

$$g(x) = \frac{\partial}{\partial x} f(x) = A^\top A x - A^\top b \tag{7.234}$$

The optimum can be found by solving $g(x) = 0$. This gives

$$A^\top A x = A^\top b \tag{7.235}$$

These are known as the normal equations, since, at the optimal solution, $b - A x$ is normal
(orthogonal) to the range of $A$, as we explain in Section 11.2.2.2. The corresponding solution $\hat{x}$ is
the ordinary least squares (OLS) solution, which is given by

$$\hat{x} = (A^\top A)^{-1} A^\top b \tag{7.236}$$

The quantity $A^\dagger = (A^\top A)^{-1} A^\top$ is the left pseudo inverse of the (non-square) matrix $A$ (see
Section 7.5.3 for more details).

We can check that the solution is unique by showing that the Hessian is positive definite. In this
case, the Hessian is given by

$$H(x) = \frac{\partial^2}{\partial x^2} f(x) = A^\top A \tag{7.237}$$

If $A$ is full rank (so the columns of $A$ are linearly independent), then $H$ is positive definite, since for
any $v \neq 0$, we have

$$v^\top (A^\top A) v = (A v)^\top (A v) = \|A v\|^2 > 0 \tag{7.238}$$

Hence in the full rank case, the least squares objective has a unique global minimum.
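As a sketch, the normal equations can be solved directly and compared with NumPy's least squares routine `np.linalg.lstsq`. The system below is random (an assumption for illustration).

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(10, 3))    # m = 10 > n = 3, full column rank with probability 1
b = rng.normal(size=10)

# Solve the normal equations A^T A x = A^T b.
x_ne = np.linalg.solve(A.T @ A, A.T @ b)

# Library solver; usually preferred for numerical stability.
x_ls, *_ = np.linalg.lstsq(A, b, rcond=None)

# At the optimum the residual is orthogonal to the columns (range) of A.
residual = b - A @ x_ne
print(np.allclose(x_ne, x_ls), np.abs(A.T @ residual).max())
```

Forming $A^\top A$ squares the condition number of the problem, which is one reason library solvers based on QR or SVD are usually preferred over the explicit normal equations.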


7.8 Matrix calculus 


The topic of calculus concerns computing “rates of change” of functions as we vary their inputs. It 
is of vital importance to machine learning, as well as almost every other numerical discipline. In this 
section, we review some standard results. In some cases, we use some concepts and notation from 
matrix algebra, which we cover in Chapter 7. For more details on these results from a deep learning 
perspective, see [PH18]. 


7.8.1 Derivatives 


Consider a scalar-argument function $f : \mathbb{R} \to \mathbb{R}$. We define its derivative at a point $x$ to be the
quantity

$$f'(x) \triangleq \lim_{h \to 0} \frac{f(x + h) - f(x)}{h} \tag{7.239}$$

assuming the limit exists. This measures how quickly the output changes when we move a small
distance in input space away from $x$ (i.e., the “rate of change” of the function). We can interpret
$f'(x)$ as the slope of the tangent line at $f(x)$, and hence

$$f(x + h) \approx f(x) + f'(x) h \tag{7.240}$$

for small $h$.
We can compute a finite difference approximation to the derivative by using a finite step size $h$,
as follows:

$$f'(x) \equiv \underbrace{\lim_{h \to 0} \frac{f(x + h) - f(x)}{h}}_{\text{forward difference}} = \underbrace{\lim_{h \to 0} \frac{f(x + h/2) - f(x - h/2)}{h}}_{\text{central difference}} = \underbrace{\lim_{h \to 0} \frac{f(x) - f(x - h)}{h}}_{\text{backward difference}} \tag{7.241}$$

The smaller the step size $h$, the better the estimate, although if $h$ is too small, there can be errors
due to numerical cancellation.
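The trade-off between truncation error (large $h$) and cancellation error (tiny $h$) can be seen in a short experiment. The sketch below uses $f = \sin$, whose derivative $\cos$ is known exactly; the choice of function, evaluation point and step sizes are assumptions for illustration.

```python
import numpy as np

def forward_diff(f, x, h):
    return (f(x + h) - f(x)) / h

def central_diff(f, x, h):
    return (f(x + h / 2) - f(x - h / 2)) / h

def backward_diff(f, x, h):
    return (f(x) - f(x - h)) / h

f = np.sin              # example function with known derivative cos
x = 1.0
exact = np.cos(x)

# Print the absolute error of each scheme for a range of step sizes.
for h in [1e-1, 1e-4, 1e-8, 1e-12]:
    errs = [abs(d(f, x, h) - exact)
            for d in (forward_diff, central_diff, backward_diff)]
    print(h, errs)
```

For moderate step sizes the central difference is typically the most accurate of the three, while for very small step sizes all of them degrade due to floating point cancellation.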

We can think of differentiation as an operator that maps functions to functions, $D(f) = f'$, where
$f'(x)$ computes the derivative at $x$ (assuming the derivative exists at that point). The use of the
prime symbol $f'$ to denote the derivative is called Lagrange notation. The second derivative
function, which measures how quickly the gradient is changing, is denoted by $f''$. The $n$'th derivative
function is denoted $f^{(n)}$.

Alternatively, we can use Leibniz notation, in which we denote the function by $y = f(x)$, and its
derivative by $\frac{dy}{dx}$ or $\frac{d}{dx} f(x)$. To denote the evaluation of the derivative at a point $a$, we write
$\left. \frac{df}{dx} \right|_{x=a}$.

7.8.2 Gradients 


We can extend the notion of derivatives to handle vector-argument functions, $f : \mathbb{R}^n \to \mathbb{R}$, by defining
the partial derivative of $f$ with respect to $x_i$ to be

$$\frac{\partial f}{\partial x_i} = \lim_{h \to 0} \frac{f(x + h e_i) - f(x)}{h} \tag{7.242}$$

where $e_i$ is the $i$'th unit vector.
The gradient of a function at a point $x$ is the vector of its partial derivatives:

$$g = \frac{\partial f}{\partial x} = \nabla f = \begin{pmatrix} \frac{\partial f}{\partial x_1} \\ \vdots \\ \frac{\partial f}{\partial x_n} \end{pmatrix} \tag{7.243}$$

To emphasize the point at which the gradient is evaluated, we can write

$$g(x^*) \triangleq \left. \frac{\partial f}{\partial x} \right|_{x^*} \tag{7.244}$$

We see that the operator $\nabla$ (pronounced “nabla”) maps a function $f : \mathbb{R}^n \to \mathbb{R}$ to another function
$g : \mathbb{R}^n \to \mathbb{R}^n$. Since $g()$ is a vector-valued function, it is known as a vector field. By contrast, the
derivative function $f'$ is a scalar field.


7.8.3 Directional derivative 


The directional derivative measures how much the function $f : \mathbb{R}^n \to \mathbb{R}$ changes along a direction
$v$ in space. It is defined as follows

$$D_v f(x) = \lim_{h \to 0} \frac{f(x + h v) - f(x)}{h} \tag{7.245}$$

We can approximate this numerically using 2 function calls to $f$, regardless of $n$. By contrast, a
numerical approximation to the standard gradient vector takes $n + 1$ calls (or $2n$ if using central
differences).

Note that the directional derivative along $v$ is the scalar product of the gradient $g$ and the vector
$v$:

$$D_v f(x) = \nabla f(x) \cdot v \tag{7.246}$$


7.8.4 Total derivative * 


Suppose that some of the arguments to the function depend on each other. Concretely, suppose the 
function has the form f(t, x(t), y(t)). We define the total derivative of f wrt t as follows: 


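$$\frac{df}{dt} = \frac{\partial f}{\partial t} + \frac{\partial f}{\partial x} \frac{dx}{dt} + \frac{\partial f}{\partial y} \frac{dy}{dt} \tag{7.247}$$

If we multiply both sides by the differential $dt$, we get the total differential

$$df = \frac{\partial f}{\partial t} dt + \frac{\partial f}{\partial x} dx + \frac{\partial f}{\partial y} dy \tag{7.248}$$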


This measures how much f changes when we change t, both via the direct effect of t on f, but also 
indirectly, via the effects of t on x and y. 


7.8.5 Jacobian 


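Consider a function that maps a vector to another vector, $f : \mathbb{R}^n \to \mathbb{R}^m$. The Jacobian matrix of
this function is an $m \times n$ matrix of partial derivatives:

$$J_f(x) = \frac{\partial f}{\partial x^\top} \triangleq \begin{pmatrix} \frac{\partial f_1}{\partial x_1} & \cdots & \frac{\partial f_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial f_m}{\partial x_1} & \cdots & \frac{\partial f_m}{\partial x_n} \end{pmatrix} = \begin{pmatrix} \nabla f_1(x)^\top \\ \vdots \\ \nabla f_m(x)^\top \end{pmatrix} \tag{7.249}$$

Note that we lay out the results in the same orientation as the output $f$; this is sometimes called
numerator layout or the Jacobian formulation.⁵

5. For a much more detailed discussion of notation, see https://en.wikipedia.org/wiki/Matrix_calculus.


7.8.5.1 Multiplying Jacobians and vectors


The Jacobian vector product or JVP is defined to be the operation that corresponds to right-multiplying
the Jacobian matrix $J \in \mathbb{R}^{m \times n}$ by a vector $v \in \mathbb{R}^n$:

$$J_f(x) v = \begin{pmatrix} \nabla f_1(x)^\top \\ \vdots \\ \nabla f_m(x)^\top \end{pmatrix} v = \begin{pmatrix} \nabla f_1(x)^\top v \\ \vdots \\ \nabla f_m(x)^\top v \end{pmatrix} \tag{7.250}$$

So we can see that we can approximate this numerically using just 2 calls to $f$.
The vector Jacobian product or VJP is defined to be the operation that corresponds to
left-multiplying the Jacobian matrix $J \in \mathbb{R}^{m \times n}$ by a vector $u \in \mathbb{R}^m$:

$$u^\top J_f(x) = u^\top \left( \frac{\partial f}{\partial x_1}, \cdots, \frac{\partial f}{\partial x_n} \right) = \left( u \cdot \frac{\partial f}{\partial x_1}, \cdots, u \cdot \frac{\partial f}{\partial x_n} \right) \tag{7.251}$$

The JVP is more efficient if $m \geq n$, and the VJP is more efficient if $m \leq n$. See Section 13.3 for
details on how this can be used to perform automatic differentiation in a computation graph such as
a DNN.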
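The claim that a JVP needs only two function evaluations can be illustrated with a finite difference sketch. The test function `f`, its hand-derived `jacobian`, and the helper `jvp_fd` are hypothetical names introduced for illustration; this is a numerical approximation, not the automatic differentiation approach of Section 13.3.

```python
import numpy as np

def f(x):
    """Hypothetical test function f : R^3 -> R^2."""
    return np.array([x[0] * x[1], np.sin(x[2]) + x[0] ** 2])

def jacobian(x):
    """Analytic Jacobian of f in numerator layout, shape (2, 3)."""
    return np.array([[x[1], x[0], 0.0],
                     [2 * x[0], 0.0, np.cos(x[2])]])

def jvp_fd(f, x, v, eps=1e-6):
    """Approximate J_f(x) v with two evaluations of f (central difference)."""
    return (f(x + eps * v) - f(x - eps * v)) / (2 * eps)

x = np.array([1.0, 2.0, 0.5])
v = np.array([0.3, -1.0, 2.0])   # direction in input space, v in R^n
u = np.array([1.0, -2.0])        # weights in output space, u in R^m

jvp_exact = jacobian(x) @ v      # J v, shape (m,)
vjp_exact = u @ jacobian(x)      # u^T J, shape (n,)
jvp_approx = jvp_fd(f, x, v)
print(jvp_exact, jvp_approx, vjp_exact)
```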


7.8.5.2 Jacobian of a composition 


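Sometimes it is useful to take the Jacobian of the composition of two functions. Let $h(x) = g(f(x))$.
By the chain rule of calculus, we have

$$J_h(x) = J_g(f(x)) J_f(x) \tag{7.252}$$

For example, suppose $f : \mathbb{R} \to \mathbb{R}^2$ and $g : \mathbb{R}^2 \to \mathbb{R}^2$. We have

$$\frac{\partial g}{\partial x} = \begin{pmatrix} \frac{\partial}{\partial x} g_1(f_1(x), f_2(x)) \\ \frac{\partial}{\partial x} g_2(f_1(x), f_2(x)) \end{pmatrix} = \begin{pmatrix} \frac{\partial g_1}{\partial f_1} \frac{\partial f_1}{\partial x} + \frac{\partial g_1}{\partial f_2} \frac{\partial f_2}{\partial x} \\ \frac{\partial g_2}{\partial f_1} \frac{\partial f_1}{\partial x} + \frac{\partial g_2}{\partial f_2} \frac{\partial f_2}{\partial x} \end{pmatrix} \tag{7.253}$$

$$= \frac{\partial g}{\partial f^\top} \frac{\partial f}{\partial x} = \begin{pmatrix} \frac{\partial g_1}{\partial f_1} & \frac{\partial g_1}{\partial f_2} \\ \frac{\partial g_2}{\partial f_1} & \frac{\partial g_2}{\partial f_2} \end{pmatrix} \begin{pmatrix} \frac{\partial f_1}{\partial x} \\ \frac{\partial f_2}{\partial x} \end{pmatrix} \tag{7.254}$$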

7.8.6 Hessian 


For a function f : R” — R that is twice differentiable, we define the Hessian matrix as the 
(symmetric) n x n matrix of second partial derivatives: 


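$$H_f = \frac{\partial^2 f}{\partial x^2} = \nabla^2 f = \begin{pmatrix} \frac{\partial^2 f}{\partial x_1^2} & \cdots & \frac{\partial^2 f}{\partial x_1 \partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial^2 f}{\partial x_n \partial x_1} & \cdots & \frac{\partial^2 f}{\partial x_n^2} \end{pmatrix} \tag{7.255}$$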

We see that the Hessian is the Jacobian of the gradient. 


7.8.7 Gradients of commonly used functions 


In this section, we list without proof the gradients of certain widely used functions. 


7.8.7.1 Functions that map scalars to scalars 


Consider a differentiable function $f : \mathbb{R} \to \mathbb{R}$. Here are some useful identities from scalar calculus,
which you should already be familiar with.

$$\frac{d}{dx} c x^n = c n x^{n-1} \tag{7.256}$$

$$\frac{d}{dx} \log(x) = 1/x \tag{7.257}$$

$$\frac{d}{dx} \exp(x) = \exp(x) \tag{7.258}$$

$$\frac{d}{dx} \left[ f(x) + g(x) \right] = \frac{df(x)}{dx} + \frac{dg(x)}{dx} \tag{7.259}$$

$$\frac{d}{dx} \left[ f(x) g(x) \right] = f(x) \frac{dg(x)}{dx} + g(x) \frac{df(x)}{dx} \tag{7.260}$$

$$\frac{d}{dx} f(u(x)) = \frac{du}{dx} \frac{df(u)}{du} \tag{7.261}$$

Equation (7.261) is known as the chain rule of calculus.


7.8.7.2 Functions that map vectors to scalars


Consider a differentiable function $f : \mathbb{R}^n \to \mathbb{R}$. Here are some useful identities:⁶

$$\frac{\partial (a^\top x)}{\partial x} = a \tag{7.262}$$

$$\frac{\partial (b^\top A x)}{\partial x} = A^\top b \tag{7.263}$$

$$\frac{\partial (x^\top A x)}{\partial x} = (A + A^\top) x \tag{7.264}$$

It is fairly easy to prove these identities by expanding out the quadratic form, and applying scalar
calculus.
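These identities can also be sanity-checked numerically with central differences. In the sketch below, the helper `numerical_grad` and the random test data are assumptions for illustration; the matrix `A` is deliberately not symmetric, so that the $(A + A^\top) x$ form is exercised.

```python
import numpy as np

def numerical_grad(f, x, h=1e-6):
    """Central-difference approximation to the gradient of f : R^n -> R."""
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = 1.0                           # i'th unit vector
        g[i] = (f(x + h * e) - f(x - h * e)) / (2 * h)
    return g

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))                  # deliberately non-symmetric
a, b, x = rng.normal(size=(3, 4))            # three random vectors in R^4

grad_linear = numerical_grad(lambda z: a @ z, x)       # should match a
grad_bilinear = numerical_grad(lambda z: b @ A @ z, x) # should match A^T b
grad_quad = numerical_grad(lambda z: z @ A @ z, x)     # should match (A + A^T) x

print(np.allclose(grad_linear, a, atol=1e-6))
print(np.allclose(grad_bilinear, A.T @ b, atol=1e-6))
print(np.allclose(grad_quad, (A + A.T) @ x, atol=1e-6))
```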


7.8.7.3 Functions that map matrices to scalars 


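Consider a function $f : \mathbb{R}^{m \times n} \to \mathbb{R}$ which maps a matrix to a scalar. We are using the following
natural layout for the derivative matrix:


$$
\frac{\partial f}{\partial \mathbf{X}} =
\begin{pmatrix}
\frac{\partial f}{\partial x_{11}} & \cdots & \frac{\partial f}{\partial x_{1n}} \\
\vdots & \ddots & \vdots \\
\frac{\partial f}{\partial x_{m1}} & \cdots & \frac{\partial f}{\partial x_{mn}}
\end{pmatrix} \tag{7.265}
$$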


Below are some useful identities. 


6. Some of the identities are taken from the list at http://www.cs.nyu.edu/~roweis/notes/matrixid.pdf.


Identities involving quadratic forms


One can show the following results. 


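$$ \frac{\partial}{\partial \mathbf{X}} (\mathbf{a}^\top \mathbf{X} \mathbf{b}) = \mathbf{a} \mathbf{b}^\top \tag{7.266} $$
$$ \frac{\partial}{\partial \mathbf{X}} (\mathbf{a}^\top \mathbf{X}^\top \mathbf{b}) = \mathbf{b} \mathbf{a}^\top \tag{7.267} $$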


Identities involving matrix trace 


One can show the following results. 


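$$ \frac{\partial}{\partial \mathbf{X}} \mathrm{tr}(\mathbf{A} \mathbf{X} \mathbf{B}) = \mathbf{A}^\top \mathbf{B}^\top \tag{7.268} $$
$$ \frac{\partial}{\partial \mathbf{X}} \mathrm{tr}(\mathbf{X}^\top \mathbf{A}) = \mathbf{A} \tag{7.269} $$
$$ \frac{\partial}{\partial \mathbf{X}} \mathrm{tr}(\mathbf{X}^{-1} \mathbf{A}) = -\mathbf{X}^{-\top} \mathbf{A}^\top \mathbf{X}^{-\top} \tag{7.270} $$
$$ \frac{\partial}{\partial \mathbf{X}} \mathrm{tr}(\mathbf{X}^\top \mathbf{A} \mathbf{X}) = (\mathbf{A} + \mathbf{A}^\top) \mathbf{X} \tag{7.271} $$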


Identities involving matrix determinant 


One can show the following results. 


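$$ \frac{\partial}{\partial \mathbf{X}} \det(\mathbf{A} \mathbf{X} \mathbf{B}) = \det(\mathbf{A} \mathbf{X} \mathbf{B}) \, \mathbf{X}^{-\top} \tag{7.272} $$
$$ \frac{\partial}{\partial \mathbf{X}} \log(\det(\mathbf{X})) = \mathbf{X}^{-\top} \tag{7.273} $$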


7.9 Exercises 


Exercise 7.1 [Orthogonal matrices] 
a. A rotation in 3d by angle $\alpha$ about the $z$ axis is given by the following matrix:


$$
\mathbf{R}(\alpha) = \begin{pmatrix} \cos(\alpha) & -\sin(\alpha) & 0 \\ \sin(\alpha) & \cos(\alpha) & 0 \\ 0 & 0 & 1 \end{pmatrix} \tag{7.274}
$$


Prove that $\mathbf{R}$ is an orthogonal matrix, i.e., $\mathbf{R}^\top \mathbf{R} = \mathbf{I}$, for any $\alpha$.


b. What is the only eigenvector $\mathbf{v}$ of $\mathbf{R}$ with an eigenvalue of 1.0 and of unit norm (i.e., $\|\mathbf{v}\|^2 = 1$)? (Your
answer should be the same for any $\alpha$.) Hint: think about the geometrical interpretation of eigenvectors.


Exercise 7.2 [Eigenvectors by hand *] 
Find the eigenvalues and eigenvectors of the following matrix 


$$ \mathbf{A} = \begin{pmatrix} 2 & 0 \\ 0 & 3 \end{pmatrix} \tag{7.275} $$


Compute your result by hand and check it with Python. 
8 Optimization


Parts of this chapter were written by Frederik Kunstner, Si Yi Meng, Aaron Mishkin, Sharan Vaswani, 
and Mark Schmidt. 


8.1 Introduction 


We saw in Chapter 4 that the core problem in machine learning is parameter estimation (aka model 
fitting). This requires solving an optimization problem, where we try to find the values for a set 
of variables $\boldsymbol{\theta} \in \Theta$ that minimize a scalar-valued loss function or cost function $\mathcal{L} : \Theta \to \mathbb{R}$:


$$ \boldsymbol{\theta}^* \in \operatorname*{argmin}_{\boldsymbol{\theta} \in \Theta} \mathcal{L}(\boldsymbol{\theta}) \tag{8.1} $$


We will assume that the parameter space is given by $\Theta \subseteq \mathbb{R}^D$, where $D$ is the number of variables
being optimized over. Thus we are focusing on continuous optimization, rather than discrete
optimization.

If we want to maximize a score function or reward function $R(\boldsymbol{\theta})$, we can equivalently minimize
$\mathcal{L}(\boldsymbol{\theta}) = -R(\boldsymbol{\theta})$. We will use the term objective function to refer generically to a function we want
to maximize or minimize. An algorithm that can find an optimum of an objective function is often 
called a solver. 

In the rest of this chapter, we discuss different kinds of solvers for different kinds of objective 
functions, with a focus on methods used in the machine learning community. For more details on 
optimization, please consult some of the many excellent textbooks, such as [KW19b; BV04; NW06; 
Ber15; Ber16] as well as various review articles, such as [BCN18; Sun+19b; PPS18; Pey20]. 


8.1.1 Local vs global optimization 


A point that satisfies Equation (8.1) is called a global optimum. Finding such a point is called 
global optimization. 

In general, finding global optima is computationally intractable [Neu04]. In such cases, we will 
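just try to find a local optimum. For continuous problems, this is defined to be a point $\boldsymbol{\theta}^*$ which
has lower (or equal) cost than “nearby” points. Formally, we say $\boldsymbol{\theta}^*$ is a local minimum if


$$ \exists \delta > 0, \ \forall \boldsymbol{\theta} \in \Theta \ \text{s.t.} \ \|\boldsymbol{\theta} - \boldsymbol{\theta}^*\| < \delta, \ \mathcal{L}(\boldsymbol{\theta}^*) \le \mathcal{L}(\boldsymbol{\theta}) \tag{8.2} $$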
Figure 8.1: (a) Illustration of local and global minimum in 1d. Generated by extrema_fig_1d.ipynb. (b)
Illustration of a saddle point in 2d. Generated by saddle.ipynb.


A local minimum could be surrounded by other local minima with the same objective value; this 
is known as a flat local minimum. A point is said to be a strict local minimum if its cost is 
strictly lower than those of neighboring points: 


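$$ \exists \delta > 0, \ \forall \boldsymbol{\theta} \in \Theta, \ \boldsymbol{\theta} \ne \boldsymbol{\theta}^* : \ \|\boldsymbol{\theta} - \boldsymbol{\theta}^*\| < \delta, \ \mathcal{L}(\boldsymbol{\theta}^*) < \mathcal{L}(\boldsymbol{\theta}) \tag{8.3} $$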


We can define a (strict) local maximum analogously. See Figure 8.1a for an illustration. 

A final note on terminology; if an algorithm is guaranteed to converge to a stationary point 
from any starting point, it is called globally convergent. However, this does not mean (rather 
confusingly) that it will converge to a global optimum; instead, it just means it will converge to some 
stationary point. 


8.1.1.1 Optimality conditions for local vs global optima 


For continuous, twice differentiable functions, we can precisely characterize the points which correspond
to local minima. Let $\mathbf{g}(\boldsymbol{\theta}) = \nabla \mathcal{L}(\boldsymbol{\theta})$ be the gradient vector, and $\mathbf{H}(\boldsymbol{\theta}) = \nabla^2 \mathcal{L}(\boldsymbol{\theta})$ be the
Hessian matrix. (See Section 7.8 for a refresher on these concepts, if necessary.) Consider a point
$\boldsymbol{\theta}^* \in \mathbb{R}^D$, and let $\mathbf{g}^* = \mathbf{g}(\boldsymbol{\theta})|_{\boldsymbol{\theta}^*}$ be the gradient at that point, and $\mathbf{H}^* = \mathbf{H}(\boldsymbol{\theta})|_{\boldsymbol{\theta}^*}$ be the corresponding
Hessian. One can show that the following conditions characterize every local minimum:


• Necessary condition: If $\boldsymbol{\theta}^*$ is a local minimum, then we must have $\mathbf{g}^* = \mathbf{0}$ (i.e., $\boldsymbol{\theta}^*$ must be a
stationary point), and $\mathbf{H}^*$ must be positive semi-definite.


• Sufficient condition: If $\mathbf{g}^* = \mathbf{0}$ and $\mathbf{H}^*$ is positive definite, then $\boldsymbol{\theta}^*$ is a local minimum.


To see why the first condition is necessary, suppose we were at a point $\boldsymbol{\theta}^*$ at which the gradient is
non-zero: at such a point, we could decrease the function by following the negative gradient a small
distance, so this would not be optimal. So the gradient must be zero. (In the case of nonsmooth
functions, the necessary condition is that zero is a subgradient at the minimum.) To see why
a zero gradient is not sufficient, note that the stationary point could be a local minimum, maximum
or saddle point, which is a point where some directions point downhill, and some uphill (see
Figure 8.1b). More precisely, at a saddle point, the eigenvalues of the Hessian will be both positive
and negative. However, if the Hessian at a point is positive semi-definite, then some directions may
point uphill, while others are flat. Moreover, if the Hessian is strictly positive definite, then we are at
the bottom of a “bowl”, and all directions point uphill, which is sufficient for this to be a minimum.
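The second-order test described above can be sketched in a few lines. This is a minimal illustration, assuming NumPy; the function name `classify_stationary_point`, the tolerance, and the three example functions are our own choices. It takes the Hessian at a point where the gradient is already known to be zero.

```python
import numpy as np

def classify_stationary_point(hessian, tol=1e-8):
    """Classify a zero-gradient point from the eigenvalues of its symmetric Hessian."""
    eigvals = np.linalg.eigvalsh(np.asarray(hessian, dtype=float))
    if np.all(eigvals > tol):
        return "strict local minimum"   # sufficient condition: H is positive definite
    if np.all(eigvals < -tol):
        return "strict local maximum"
    if eigvals.min() < -tol and eigvals.max() > tol:
        return "saddle point"           # eigenvalues of both signs
    return "inconclusive"               # semi-definite: the second-order test cannot decide

# Hessians evaluated at the origin, which is a stationary point of each function.
examples = {
    "x^2 + y^2": [[2.0, 0.0], [0.0, 2.0]],
    "x^2 - y^2": [[2.0, 0.0], [0.0, -2.0]],
    "x^4 + y^2": [[0.0, 0.0], [0.0, 2.0]],
}
results = {name: classify_stationary_point(H) for name, H in examples.items()}
print(results)
```

The last example shows the gap between the necessary and sufficient conditions: the Hessian is only positive semi-definite at the origin, so the test is inconclusive, even though the origin is in fact a minimum of that function.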


8.1.2 Constrained vs unconstrained optimization 


In unconstrained optimization, we define the optimization task as finding any value in the 
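parameter space $\Theta$ that minimizes the loss. However, we often have a set of constraints on the
allowable values. It is standard to partition the set of constraints $\mathcal{C}$ into inequality constraints,
$g_j(\boldsymbol{\theta}) \le 0$ for $j \in \mathcal{I}$, and equality constraints, $h_k(\boldsymbol{\theta}) = 0$ for $k \in \mathcal{E}$. For example, we can represent a
sum-to-one constraint as an equality constraint $h(\boldsymbol{\theta}) = (1 - \sum_{i=1}^D \theta_i) = 0$, and we can represent a
non-negativity constraint on the parameters by using $D$ inequality constraints of the form $g_i(\boldsymbol{\theta}) = -\theta_i \le 0$.


We define the feasible set as the subset of the parameter space that satisfies the constraints:


$$ \mathcal{C} = \{ \boldsymbol{\theta} : g_j(\boldsymbol{\theta}) \le 0 : j \in \mathcal{I}, \ h_k(\boldsymbol{\theta}) = 0 : k \in \mathcal{E} \} \subseteq \mathbb{R}^D \tag{8.4} $$


Our constrained optimization problem now becomes


$$ \boldsymbol{\theta}^* \in \operatorname*{argmin}_{\boldsymbol{\theta} \in \mathcal{C}} \mathcal{L}(\boldsymbol{\theta}) \tag{8.5} $$


If $\mathcal{C} = \mathbb{R}^D$, it is called unconstrained optimization.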

The addition of constraints can change the number of optima of a function. For example, a function 
that was previously unbounded (and hence had no well-defined global maximum or minimum) can 
“acquire” multiple maxima or minima when we add constraints, as illustrated in Figure 8.2. However, 
if we add too many constraints, we may find that the feasible set becomes empty. The task of finding 
any point (regardless of its cost) in the feasible set is called a feasibility problem; this can be a 
hard subproblem in itself. 

A common strategy for solving constrained problems is to create penalty terms that measure 
how much we violate each constraint. We then add these terms to the objective and solve an 
unconstrained optimization problem. The Lagrangian is a special case of such a combined objective 
(see Section 8.5 for details). 


8.1.3 Convex vs nonconvex optimization 


In convex optimization, we require the objective to be a convex function defined over a convex 
set (we define these terms below). In such problems, every local minimum is also a global minimum. 
Thus many models are designed so that their training objectives are convex. 


8.1.3.1 Convex sets 


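We say $\mathcal{S}$ is a convex set if, for any $\mathbf{x}, \mathbf{x}' \in \mathcal{S}$, we have


$$ \lambda \mathbf{x} + (1 - \lambda) \mathbf{x}' \in \mathcal{S}, \quad \forall \lambda \in [0, 1] \tag{8.6} $$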
Figure 8.2: Illustration of constrained maximization of a nonconvex 1d function. The area between the dotted
vertical lines represents the feasible set. (a) There is a unique global maximum since the function is concave
within the support of the feasible set. (b) There are two global maxima, both occurring at the boundary of the
feasible set. (c) In the unconstrained case, this function has no global maximum, since it is unbounded.


Figure 8.3: Illustration of some convex and non-convex sets.


That is, if we draw a line from $\mathbf{x}$ to $\mathbf{x}'$, all points on the line lie inside the set. See Figure 8.3 for
some illustrations of convex and non-convex sets. 


8.1.3.2 Convex functions 


We say f is a convex function if its epigraph (the set of points above the function, illustrated in 
Figure 8.4a) defines a convex set. Equivalently, a function $f(\mathbf{x})$ is called convex if it is defined on a
convex set $\mathcal{S}$ and if, for any $\mathbf{x}, \mathbf{y} \in \mathcal{S}$, and for any $0 \le \lambda \le 1$, we have


$$ f(\lambda \mathbf{x} + (1 - \lambda) \mathbf{y}) \le \lambda f(\mathbf{x}) + (1 - \lambda) f(\mathbf{y}) \tag{8.7} $$


See Figure 8.5(a) for a 1d example of a convex function. A function is called strictly convex if the
inequality is strict. A function $f(\mathbf{x})$ is concave if $-f(\mathbf{x})$ is convex, and strictly concave if $-f(\mathbf{x})$
is strictly convex. See Figure 8.5(b) for a 1d example of a function that is neither convex nor concave.
Figure 8.4: (a) Illustration of the epigraph of a function. (b) For a convex function $f(x)$, its epigraph can
be represented as the intersection of half-spaces defined by linear lower bounds derived from the conjugate
function $f^*(\lambda) = \max_x \lambda x - f(x)$.


Figure 8.5: (a) Illustration of a convex function. We see that the chord joining $(x, f(x))$ to $(y, f(y))$ lies
above the function. (b) A function that is neither convex nor concave. A is a local minimum, B is a global
minimum.


Here are some examples of 1d convex functions:


$e^{ax}$
$-\log x$
$x^a$, for $a > 1$, $x > 0$
$|x|^a$, for $a \ge 1$
$x \log x$, for $x > 0$


8.1.3.3 Characterization of convex functions 


Intuitively, a convex function is shaped like a bowl. Formally, one can prove the following important 
result: 
Figure 8.6: The quadratic form $f(\mathbf{x}) = \mathbf{x}^\top \mathbf{A} \mathbf{x}$ in 2d. (a) $\mathbf{A}$ is positive definite, so $f$ is convex. (b) $\mathbf{A}$ is
negative definite, so $f$ is concave. (c) $\mathbf{A}$ is positive semidefinite but singular, so $f$ is convex, but not strictly.
Notice the valley of constant height in the middle. (d) $\mathbf{A}$ is indefinite, so $f$ is neither convex nor concave.
The stationary point in the middle of the surface is a saddle point. From Figure 5 of [She94].


Theorem 8.1.1. Suppose $f : \mathbb{R}^n \to \mathbb{R}$ is twice differentiable over its domain. Then $f$ is convex iff
$\mathbf{H} = \nabla^2 f(\mathbf{x})$ is positive semi definite (Section 7.1.5.3) for all $\mathbf{x} \in \mathrm{dom}(f)$. Furthermore, $f$ is strictly
convex if $\mathbf{H}$ is positive definite.


For example, consider the quadratic form 

$$ f(\mathbf{x}) = \mathbf{x}^\top \mathbf{A} \mathbf{x} \tag{8.8} $$

This is convex if A is positive semi definite, and is strictly convex if A is positive definite. It is 
neither convex nor concave if A has eigenvalues of mixed sign. See Figure 8.6. 


8.1.3.4 Strongly convex functions 


We say a function f is strongly convex with parameter m > 0 if the following holds for all x, y in 
f’s domain: 


$$ (\nabla f(\mathbf{x}) - \nabla f(\mathbf{y}))^\top (\mathbf{x} - \mathbf{y}) \ge m \|\mathbf{x} - \mathbf{y}\|_2^2 \tag{8.9} $$


A strongly convex function is also strictly convex, but not vice versa.

If the function $f$ is twice continuously differentiable, then it is strongly convex with parameter $m$
if and only if $\nabla^2 f(\mathbf{x}) \succeq m \mathbf{I}$ for all $\mathbf{x}$ in the domain, where $\mathbf{I}$ is the identity and $\nabla^2 f$ is the Hessian
matrix, and the inequality $\succeq$ means that $\nabla^2 f(\mathbf{x}) - m \mathbf{I}$ is positive semi-definite. This is equivalent
Figure 8.7: (a) Smooth 1d function. (b) Non-smooth 1d function. (There is a discontinuity at the origin.)
Generated by smooth-vs-nonsmooth-1d.ipynb.


to requiring that the minimum eigenvalue of $\nabla^2 f(\mathbf{x})$ be at least $m$ for all $\mathbf{x}$. If the domain is just
the real line, then $\nabla^2 f(x)$ is just the second derivative $f''(x)$, so the condition becomes $f''(x) \ge m$.
If $m = 0$, then this means the Hessian is positive semidefinite (or if the domain is the real line, it
means that $f''(x) \ge 0$), which implies the function is convex, and perhaps strictly convex, but not
necessarily strongly convex.

The distinction between convex, strictly convex, and strongly convex is rather subtle. To better
understand this, consider the case where $f$ is twice continuously differentiable and the domain is the
real line. Then we can characterize the differences as follows:

• $f$ is convex if and only if $f''(x) \ge 0$ for all $x$.
• $f$ is strictly convex if $f''(x) > 0$ for all $x$ (note: this is sufficient, but not necessary).
• $f$ is strongly convex if and only if $f''(x) \ge m > 0$ for all $x$.

Note that it can be shown that a function $f$ is strongly convex with parameter $m$ iff the function


$$ J(\mathbf{x}) = f(\mathbf{x}) - \frac{m}{2} \|\mathbf{x}\|^2 \tag{8.10} $$


is convex.
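
To make the three cases concrete, here is a short worked example on the real line. The particular functions are our own illustrative choice; each line follows directly from the second-derivative conditions listed above.

$$
\begin{aligned}
f(x) = x: \quad & f''(x) = 0 \ge 0 && \Rightarrow \text{convex, but not strictly convex} \\
f(x) = x^4: \quad & f''(x) = 12 x^2 \ge 0, \ f''(0) = 0 && \Rightarrow \text{strictly convex, but not strongly convex} \\
f(x) = x^2: \quad & f''(x) = 2 \ge m \ \text{for any } 0 < m \le 2 && \Rightarrow \text{strongly convex}
\end{aligned}
$$

The second line shows why $f''(x) > 0$ is sufficient but not necessary for strict convexity: $x^4$ is strictly convex even though its second derivative vanishes at the origin, and it fails to be strongly convex because no $m > 0$ lower-bounds $f''$ near $x = 0$. For the third line, subtracting $\frac{m}{2} x^2$ from $x^2$ leaves $(1 - m/2) x^2$, which is convex exactly when $m \le 2$.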


8.1.4 Smooth vs nonsmooth optimization 


In smooth optimization, the objective and constraints are continuously differentiable functions. 
For smooth functions, we can quantify the degree of smoothness using the Lipschitz constant. In 
the 1d case, this is defined as any constant $L \ge 0$ such that, for all real $x_1$ and $x_2$, we have


$$ |f(x_1) - f(x_2)| \le L |x_1 - x_2| \tag{8.11} $$


This is illustrated in Figure 8.8: for a given constant L, the function output cannot change by more 
than L if we change the function input by 1 unit. This can be generalized to vector inputs using a 
suitable norm. 

In nonsmooth optimization, there are at least some points where the gradient of the objective 
function or the constraints is not well-defined. See Figure 8.7 for an example. In some optimization 
Figure 8.8: For a Lipschitz continuous function $f$, there exists a double cone (white) whose origin can be
moved along the graph of $f$ so that the whole graph always stays outside the double cone. From
https://en.wikipedia.org/wiki/Lipschitz_continuity. Used with kind permission of Wikipedia author Taschee.


problems, we can partition the objective into a part that only contains smooth terms, and a part 
that contains the nonsmooth terms: 


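$$ \mathcal{L}(\boldsymbol{\theta}) = \mathcal{L}_s(\boldsymbol{\theta}) + \mathcal{L}_r(\boldsymbol{\theta}) \tag{8.12} $$


where $\mathcal{L}_s$ is smooth (differentiable), and $\mathcal{L}_r$ is nonsmooth (“rough”). This is often referred to as a
composite objective. In machine learning applications, $\mathcal{L}_s$ is usually the training set loss, and
$\mathcal{L}_r$ is a regularizer, such as the $\ell_1$ norm of $\boldsymbol{\theta}$. This composite structure can be exploited by various
algorithms.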


8.1.4.1 Subgradients 


In this section, we generalize the notion of a derivative to work with functions whose derivative has local
discontinuities (i.e., functions with “kinks”). In particular, for a convex function of several variables, $f : \mathbb{R}^n \to \mathbb{R}$, we say that
$\mathbf{g} \in \mathbb{R}^n$ is a subgradient of $f$ at $\mathbf{x} \in \mathrm{dom}(f)$ if for all $\mathbf{z} \in \mathrm{dom}(f)$,


$$ f(\mathbf{z}) \ge f(\mathbf{x}) + \mathbf{g}^\top (\mathbf{z} - \mathbf{x}) \tag{8.13} $$


Note that a subgradient can exist even when $f$ is not differentiable at a point, as shown in Figure 8.9.
A function $f$ is called subdifferentiable at $\mathbf{x}$ if there is at least one subgradient at $\mathbf{x}$. The set of
such subgradients is called the subdifferential of $f$ at $\mathbf{x}$, and is denoted $\partial f(\mathbf{x})$.
For example, consider the absolute value function $f(x) = |x|$. Its subdifferential is given by


$$
\partial f(x) = \begin{cases} \{-1\} & \text{if } x < 0 \\ [-1, 1] & \text{if } x = 0 \\ \{+1\} & \text{if } x > 0 \end{cases} \tag{8.14}
$$


where the notation $[-1, 1]$ means any value between $-1$ and $1$ inclusive. See Figure 8.10 for an
illustration. 
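
The subdifferential of the absolute value function is simple enough to code directly. The sketch below uses plain Python; the function names `abs_subdifferential` and `is_subgradient_of_abs` are hypothetical helpers introduced only for this illustration. The second function checks the defining subgradient inequality on a finite set of test points, so it is a numerical spot check rather than a proof.

```python
def abs_subdifferential(x):
    """Return the subdifferential of f(x) = |x| at x as a closed interval (lo, hi)."""
    if x < 0:
        return (-1.0, -1.0)
    if x > 0:
        return (1.0, 1.0)
    return (-1.0, 1.0)  # at the kink, every slope in [-1, 1] is a subgradient

def is_subgradient_of_abs(g, x, test_points, slack=1e-12):
    """Check f(z) >= f(x) + g * (z - x) for f = |.| on the given test points."""
    return all(abs(z) >= abs(x) + g * (z - x) - slack for z in test_points)

test_points = [-2.0, -0.5, 0.0, 0.5, 2.0]
print(abs_subdifferential(0.0))
print(is_subgradient_of_abs(0.3, 0.0, test_points))   # slope inside [-1, 1]
print(is_subgradient_of_abs(1.5, 0.0, test_points))   # slope outside [-1, 1]
```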


8.2 First-order methods 


In this section, we consider iterative optimization methods that leverage first-order derivatives of 
the objective function, i.e., they compute which directions point “downhill”, but they ignore curvature 
Figure 8.9: Illustration of subgradients. At $x_1$, the convex function $f$ is differentiable, and $g_1$ (which is
the derivative of $f$ at $x_1$) is the unique subgradient at $x_1$. At the point $x_2$, $f$ is not differentiable, because
of the “kink”. However, there are many subgradients at this point, of which two are shown. From
https://web.stanford.edu/class/ee364b/Lectures/subgradients_slides.pdf. Used with kind permission
of Stephen Boyd.


Figure 8.10: The absolute value function (left) and its subdifferential (right). From
https://web.stanford.edu/class/ee364b/Lectures/subgradients_slides.pdf. Used with kind permission of Stephen Boyd.


information. All of these algorithms require that the user specify a starting point $\boldsymbol{\theta}_0$. Then at each
iteration $t$, they perform an update of the following form:


$$ \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t + \eta_t \mathbf{d}_t \tag{8.15} $$


where $\eta_t$ is known as the step size or learning rate, and $\mathbf{d}_t$ is a descent direction, such as the
negative of the gradient, given by $\mathbf{g}_t = \nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta})|_{\boldsymbol{\theta}_t}$. These update steps are continued until the
method reaches a stationary point, where the gradient is zero.
8.2.1 Descent direction


We say that a direction $\mathbf{d}$ is a descent direction if there is a small enough (but nonzero) amount $\eta$
we can move in direction $\mathbf{d}$ and be guaranteed to decrease the function value. Formally, we require
that there exists an $\eta_{\max} > 0$ such that


$$ \mathcal{L}(\boldsymbol{\theta} + \eta \mathbf{d}) < \mathcal{L}(\boldsymbol{\theta}) \tag{8.16} $$


for all $0 < \eta < \eta_{\max}$. The gradient at the current iterate, $\boldsymbol{\theta}_t$, is given by


$$ \mathbf{g}_t \triangleq \nabla \mathcal{L}(\boldsymbol{\theta})|_{\boldsymbol{\theta}_t} = \nabla \mathcal{L}(\boldsymbol{\theta}_t) = \mathbf{g}(\boldsymbol{\theta}_t) \tag{8.17} $$


This points in the direction of maximal increase in $\mathcal{L}$, so the negative gradient is a descent direction.
It can be shown that any direction $\mathbf{d}$ is also a descent direction if the angle $\theta$ between $\mathbf{d}$ and $\mathbf{g}_t$ is
greater than 90 degrees (equivalently, the angle between $\mathbf{d}$ and $-\mathbf{g}_t$ is less than 90 degrees), so that


$$ \mathbf{d}^\top \mathbf{g}_t = \|\mathbf{d}\| \, \|\mathbf{g}_t\| \cos(\theta) < 0 \tag{8.18} $$


It seems that the best choice would be to pick $\mathbf{d} = -\mathbf{g}_t$. This is known as the direction of steepest
descent. However, this can be quite slow. We consider faster versions later.


8.2.2 Step size (learning rate)


In machine learning, the sequence of step sizes $\{\eta_t\}$ is called the learning rate schedule. There are
several widely used methods for picking this, some of which we discuss below. (See also Section 8.4.3, 
where we discuss schedules for stochastic optimization.) 


8.2.2.1 Constant step size 


The simplest method is to use a constant step size, $\eta_t = \eta$. However, if it is too large, the method
may fail to converge, and if it is too small, the method will converge but very slowly.
For example, consider the convex function


$$ \mathcal{L}(\boldsymbol{\theta}) = 0.5 (\theta_1^2 - \theta_2)^2 + 0.5 (\theta_1 - 1)^2 \tag{8.19} $$


Let us pick as our descent direction $\mathbf{d}_t = -\mathbf{g}_t$. Figure 8.11 shows what happens if we use this descent
direction with a fixed step size, starting from (0,0). In Figure 8.11(a), we use a small step size of
$\eta = 0.1$; we see that the iterates move slowly along the valley. In Figure 8.11(b), we use a larger step
size $\eta = 0.6$; we see that the iterates start oscillating up and down the sides of the valley and never
converge to the optimum, even though this is a convex problem.
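
The experiment just described can be reproduced with a few lines of NumPy. This is a minimal sketch, not the code of the notebook cited in the figure caption: the function names `loss`, `grad`, and `steepest_descent` are our own, and the gradient was derived by hand from Equation (8.19).

```python
import numpy as np

def loss(theta):
    """Objective of Equation (8.19)."""
    t1, t2 = theta
    return 0.5 * (t1**2 - t2) ** 2 + 0.5 * (t1 - 1.0) ** 2

def grad(theta):
    """Hand-derived gradient of the objective above."""
    t1, t2 = theta
    r = t1**2 - t2
    return np.array([2.0 * t1 * r + (t1 - 1.0), -r])

def steepest_descent(theta0, step_size, num_steps=20):
    """Run gradient descent with a constant step size; return all iterates."""
    theta = np.asarray(theta0, dtype=float)
    path = [theta]
    for _ in range(num_steps):
        theta = theta - step_size * grad(theta)  # d_t = -g_t
        path.append(theta)
    return np.array(path)

path_small = steepest_descent([0.0, 0.0], step_size=0.1)
path_large = steepest_descent([0.0, 0.0], step_size=0.6)
for name, path in [("eta=0.1", path_small), ("eta=0.6", path_large)]:
    print(name, "final iterate:", path[-1], "final loss:", loss(path[-1]))
```

Plotting the rows of `path_small` and `path_large` over the contours of `loss` gives a picture comparable to the two panels of the figure.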

In some cases, we can derive a theoretical upper bound on the maximum step size we can use. For
example, consider a quadratic objective, $\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{2} \boldsymbol{\theta}^\top \mathbf{A} \boldsymbol{\theta} + \mathbf{b}^\top \boldsymbol{\theta} + c$ with $\mathbf{A} \succeq 0$. One can show that
steepest descent will have global convergence iff the step size satisfies


$$ \eta < \frac{2}{\lambda_{\max}(\mathbf{A})} \tag{8.20} $$


where $\lambda_{\max}(\mathbf{A})$ is the largest eigenvalue of $\mathbf{A}$. The intuitive reason for this can be understood by
thinking of a ball rolling down a valley. We want to make sure it doesn’t take a step that is larger than
the slope of the steepest direction, which is what the largest eigenvalue measures (see Section 3.2.2).
Figure 8.11: Steepest descent on a simple convex function, starting from (0,0), for 20 steps, using a fixed step
size. The global minimum is at (1,1). (a) $\eta = 0.1$. (b) $\eta = 0.6$. Generated by steepestDescentDemo.ipynb.


More generally, setting $\eta < 2/L$, where $L$ is the Lipschitz constant of the gradient (Section 8.1.4),
ensures convergence. Since this constant is generally unknown, we usually need to adapt the step 
size, as we discuss below. 


8.2.2.2 Line search 


The optimal step size can be found by finding the value that maximally decreases the objective along 
the chosen direction by solving the 1d minimization problem 


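$$ \eta_t = \operatorname*{argmin}_{\eta > 0} \phi_t(\eta) = \operatorname*{argmin}_{\eta > 0} \mathcal{L}(\boldsymbol{\theta}_t + \eta \mathbf{d}_t) \tag{8.21} $$


This is known as line search, since we are searching along the line defined by $\mathbf{d}_t$.
If the loss is convex, this subproblem is also convex, because $\phi_t(\eta) = \mathcal{L}(\boldsymbol{\theta}_t + \eta \mathbf{d}_t)$ is a convex
function of an affine function of $\eta$, for fixed $\boldsymbol{\theta}_t$ and $\mathbf{d}_t$. For example, consider the quadratic loss


$$ \mathcal{L}(\boldsymbol{\theta}) = \frac{1}{2} \boldsymbol{\theta}^\top \mathbf{A} \boldsymbol{\theta} + \mathbf{b}^\top \boldsymbol{\theta} + c \tag{8.22} $$


Computing the derivative of $\phi$ gives


$$ \frac{d \phi(\eta)}{d \eta} = \frac{d}{d \eta} \left[ \frac{1}{2} (\boldsymbol{\theta} + \eta \mathbf{d})^\top \mathbf{A} (\boldsymbol{\theta} + \eta \mathbf{d}) + \mathbf{b}^\top (\boldsymbol{\theta} + \eta \mathbf{d}) + c \right] \tag{8.23} $$
$$ = \mathbf{d}^\top \mathbf{A} (\boldsymbol{\theta} + \eta \mathbf{d}) + \mathbf{d}^\top \mathbf{b} \tag{8.24} $$
$$ = \mathbf{d}^\top (\mathbf{A} \boldsymbol{\theta} + \mathbf{b}) + \eta \, \mathbf{d}^\top \mathbf{A} \mathbf{d} \tag{8.25} $$


Solving for $\frac{d \phi(\eta)}{d \eta} = 0$ gives


$$ \eta = - \frac{\mathbf{d}^\top (\mathbf{A} \boldsymbol{\theta} + \mathbf{b})}{\mathbf{d}^\top \mathbf{A} \mathbf{d}} \tag{8.26} $$
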
Using the optimal step size is known as exact line search. However, it is not usually necessary
to be so precise. There are several methods, such as the Armijo backtracking method, that try
to ensure sufficient reduction in the objective function without spending too much time trying to
solve Equation (8.21). In particular, we can start with the current stepsize (or some maximum value),
and then reduce it by a factor $0 < \beta < 1$ at each step until we satisfy the following condition, known
as the Armijo-Goldstein test:


$$ \mathcal{L}(\boldsymbol{\theta}_t + \eta \mathbf{d}_t) \le \mathcal{L}(\boldsymbol{\theta}_t) + c \, \eta \, \mathbf{d}_t^\top \nabla \mathcal{L}(\boldsymbol{\theta}_t) \tag{8.27} $$


where $c \in [0, 1]$ is a constant, typically $c = 10^{-4}$. In practice, the initialization of the line-search and
how to backtrack can significantly affect performance. See [NW06, Sec 3.1] for details.
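
A backtracking search based on the test in Equation (8.27) might look as follows. This is a minimal sketch under our own assumptions: the function name `armijo_backtracking`, its default arguments, and the cap `max_backtracks` are illustrative, and the shrink factor is called `beta` in the code to keep it distinct from the Armijo constant `c`. The quadratic used to exercise it borrows the matrix and vector from the left panel of Figure 8.12.

```python
import numpy as np

def armijo_backtracking(loss, grad, theta, direction,
                        eta_init=1.0, beta=0.5, c=1e-4, max_backtracks=50):
    """Shrink eta by beta until the sufficient-decrease test (8.27) holds."""
    loss0 = loss(theta)
    slope = direction @ grad(theta)  # d^T grad L; negative for a descent direction
    eta = eta_init
    for _ in range(max_backtracks):
        if loss(theta + eta * direction) <= loss0 + c * eta * slope:
            break
        eta *= beta
    return eta  # if the cap is hit, this is just the smallest step tried

# Quadratic test problem: L(theta) = 0.5 theta^T A theta + b^T theta + 10.
A = np.array([[20.0, 5.0], [5.0, 2.0]])
b = np.array([-14.0, -6.0])
quad_loss = lambda th: 0.5 * th @ A @ th + b @ th + 10.0
quad_grad = lambda th: A @ th + b

theta = np.zeros(2)
d = -quad_grad(theta)  # steepest descent direction
eta = armijo_backtracking(quad_loss, quad_grad, theta, d)
print("accepted step size:", eta)
```

Starting from a step of 1 and halving, the search accepts the first step that yields sufficient decrease; it does not try to find the exact minimizer along the line.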


8.2.3 Convergence rates 


We want to find optimization algorithms that converge quickly to a (local) optimum. For certain 
convex problems, with a gradient with bounded Lipschitz constant, one can show that gradient 
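descent converges at a linear rate. This means that there exists a number $0 < \mu < 1$ such that


$$ |\mathcal{L}(\boldsymbol{\theta}_{t+1}) - \mathcal{L}(\boldsymbol{\theta}_*)| \le \mu \, |\mathcal{L}(\boldsymbol{\theta}_t) - \mathcal{L}(\boldsymbol{\theta}_*)| \tag{8.28} $$


Here $\mu$ is called the rate of convergence.

For some simple problems, we can derive the convergence rate explicitly. For example, consider a
quadratic objective $\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{2} \boldsymbol{\theta}^\top \mathbf{A} \boldsymbol{\theta} + \mathbf{b}^\top \boldsymbol{\theta} + c$ with $\mathbf{A} \succ 0$. Suppose we use steepest descent with
exact line search. One can show (see e.g., [Ber15]) that the convergence rate is given by


$$ \mu = \left( \frac{\lambda_{\max} - \lambda_{\min}}{\lambda_{\max} + \lambda_{\min}} \right)^2 \tag{8.29} $$


where $\lambda_{\max}$ is the largest eigenvalue of $\mathbf{A}$ and $\lambda_{\min}$ is the smallest eigenvalue. We can rewrite this
as $\mu = \left( \frac{\kappa - 1}{\kappa + 1} \right)^2$, where $\kappa = \frac{\lambda_{\max}}{\lambda_{\min}}$ is the condition number of $\mathbf{A}$. Intuitively, the condition number
measures how “skewed” the space is, in the sense of being far from a symmetrical “bowl”. (See
Section 7.1.4.4 for more information on condition numbers.)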

Figure 8.12 illustrates the effect of the condition number on the convergence rate. On the left 
we show an example where $\mathbf{A} = [20, 5; 5, 2]$, $\mathbf{b} = [-14; -6]$ and $c = 10$, so $\kappa(\mathbf{A}) = 30.234$. On the
right we show an example where $\mathbf{A} = [20, 5; 5, 16]$, $\mathbf{b} = [-14; -6]$ and $c = 10$, so $\kappa(\mathbf{A}) = 1.854$. We
see that steepest descent converges much more quickly for the problem with the smaller condition 
number. 
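
The two condition numbers, and the bound of Equation (8.29), can be checked numerically. The sketch below assumes NumPy; the helper names `steepest_descent_exact` and `analyze` are our own, the starting point (the origin) is an arbitrary choice, and the constant $c$ is dropped because it does not affect the gap to the optimum. With $\mathbf{d} = -\mathbf{g}$, the exact step of Equation (8.26) reduces to $\eta = \mathbf{g}^\top \mathbf{g} / (\mathbf{g}^\top \mathbf{A} \mathbf{g})$.

```python
import numpy as np

def steepest_descent_exact(A, b, theta0, num_steps):
    """Steepest descent on 0.5 th^T A th + b^T th with exact line search."""
    theta = np.asarray(theta0, dtype=float)
    losses = []
    for _ in range(num_steps):
        g = A @ theta + b
        losses.append(0.5 * theta @ A @ theta + b @ theta)
        if np.linalg.norm(g) < 1e-12:
            break
        eta = (g @ g) / (g @ A @ g)  # Equation (8.26) with d = -g
        theta = theta - eta * g
    return np.array(losses)

def analyze(A, b, num_steps=30):
    """Return the condition number, the rate bound mu, and the optimality gaps."""
    eigvals = np.linalg.eigvalsh(A)          # sorted in ascending order
    kappa = eigvals[-1] / eigvals[0]
    mu = ((kappa - 1.0) / (kappa + 1.0)) ** 2
    theta_star = np.linalg.solve(A, -b)      # stationary point: A theta + b = 0
    loss_star = 0.5 * theta_star @ A @ theta_star + b @ theta_star
    gaps = steepest_descent_exact(A, b, np.zeros(2), num_steps) - loss_star
    return kappa, mu, gaps

b = np.array([-14.0, -6.0])
kappa_ill, mu_ill, gaps_ill = analyze(np.array([[20.0, 5.0], [5.0, 2.0]]), b)
kappa_well, mu_well, gaps_well = analyze(np.array([[20.0, 5.0], [5.0, 16.0]]), b)
print("ill-conditioned:  kappa = %.3f, mu = %.3f" % (kappa_ill, mu_ill))
print("well-conditioned: kappa = %.3f, mu = %.3f" % (kappa_well, mu_well))
```

Comparing successive entries of `gaps_ill` and `gaps_well` shows how much more slowly the optimality gap shrinks when the condition number is large.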

In the more general case of non-quadratic functions, the objective will often be locally quadratic 
around a local optimum. Hence the convergence rate depends on the condition number of the Hessian, 
$\kappa(\mathbf{H})$, at that point. We can often improve the convergence speed by optimizing a surrogate objective
(or model) at each step which has a Hessian that is close to the Hessian of the objective function as 
we discuss in Section 8.3. 

Although line search works well, we see from Figure 8.12 that the path of steepest descent with an 
exact line-search exhibits a characteristic zig-zag behavior, which is inefficient. This problem can be 
overcome using a method called conjugate gradient descent (see e.g., [She94]). 
Figure 8.12: Illustration of the effect of condition number $\kappa$ on the convergence speed of steepest descent with
exact line searches. (a) Large $\kappa$ (condition number of A = 30.234). (b) Small $\kappa$ (condition number of A = 1.854).
Generated by lineSearchConditionNum.ipynb.


8.2.4 Momentum methods 


Gradient descent can move very slowly along flat regions of the loss landscape, as we illustrated in 
Figure 8.11. We discuss some solutions to this below. 


8.2.4.1 Momentum 


One simple heuristic, known as the heavy ball or momentum method [Ber99], is to move faster 
along directions that were previously good, and to slow down along directions where the gradient has 
suddenly changed, just like a ball rolling downhill. This can be implemented as follows: 


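$$ \mathbf{m}_t = \beta \mathbf{m}_{t-1} + \mathbf{g}_{t-1} \tag{8.30} $$
$$ \boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t \mathbf{m}_t \tag{8.31} $$


where $\mathbf{m}_t$ is the momentum (mass times velocity) and $0 < \beta < 1$. A typical value of $\beta$ is 0.9. For
$\beta = 0$, the method reduces to gradient descent.

We see that $\mathbf{m}_t$ is like an exponentially weighted moving average of the past gradients (see
Section 4.4.2.2):


$$ \mathbf{m}_t = \beta \mathbf{m}_{t-1} + \mathbf{g}_{t-1} = \beta^2 \mathbf{m}_{t-2} + \beta \mathbf{g}_{t-2} + \mathbf{g}_{t-1} = \cdots = \sum_{\tau=0}^{t-1} \beta^\tau \mathbf{g}_{t-\tau-1} \tag{8.32} $$


If all the past gradients are a constant, say $\mathbf{g}$, this simplifies to


$$ \mathbf{m}_t = \mathbf{g} \sum_{\tau=0}^{t-1} \beta^\tau \tag{8.33} $$


The scaling factor is a geometric series, whose infinite sum is given by


$$ 1 + \beta + \beta^2 + \cdots = \sum_{i=0}^{\infty} \beta^i = \frac{1}{1 - \beta} \tag{8.34} $$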
Figure 8.13: Illustration of the Nesterov update. Adapted from Figure 11.6 of [Gér19].


Thus in the limit, we multiply the gradient by $1/(1 - \beta)$. For example, if $\beta = 0.9$, we scale the
gradient up by 10.

Since we update the parameters using the gradient average $\mathbf{m}_t$, rather than just the most recent
gradient, $\mathbf{g}_{t-1}$, we see that past gradients can exhibit some influence on the present. Furthermore,
when momentum is combined with SGD, discussed in Section 8.4, we will see that it can simulate
the effects of a larger minibatch, without the computational cost.
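
The scaling effect can be seen by running the momentum recursion on a constant gradient. The following plain-Python sketch uses a hypothetical helper, `momentum_buffer`, that applies only the accumulation rule $\mathbf{m}_t = \beta \mathbf{m}_{t-1} + \mathbf{g}_{t-1}$ to a sequence of scalar gradients, starting from a zero buffer (an assumption on our part).

```python
def momentum_buffer(gradients, beta=0.9):
    """Accumulate m_t = beta * m_{t-1} + g_{t-1}, starting from m_0 = 0."""
    m = 0.0
    history = []
    for g in gradients:
        m = beta * m + g
        history.append(m)
    return history

# Feed in a constant gradient of 1.0 for 200 steps.
history = momentum_buffer([1.0] * 200, beta=0.9)
print("after 1 step:   ", history[0])
print("after 10 steps: ", history[9])
print("after 200 steps:", history[-1])
print("limit 1/(1-beta):", 1.0 / (1.0 - 0.9))
```

The buffer grows along the partial sums of the geometric series and levels off near the limiting scale factor, which is why a step size tuned for plain gradient descent is usually too large once momentum is switched on.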


8.2.4.2 Nesterov momentum 


One problem with the standard momentum method is that it may not slow down enough at the 
bottom of a valley, causing oscillation. The Nesterov accelerated gradient method of [Nes04] 
instead modifies the gradient descent to include an extrapolation step, as follows: 


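$$ \tilde{\boldsymbol{\theta}}_{t+1} = \boldsymbol{\theta}_t + \beta (\boldsymbol{\theta}_t - \boldsymbol{\theta}_{t-1}) \tag{8.35} $$
$$ \boldsymbol{\theta}_{t+1} = \tilde{\boldsymbol{\theta}}_{t+1} - \eta_t \nabla \mathcal{L}(\tilde{\boldsymbol{\theta}}_{t+1}) \tag{8.36} $$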


This is essentially a form of one-step “look ahead”, that can reduce the amount of oscillation, as 
illustrated in Figure 8.13. 

Nesterov accelerated gradient can also be rewritten in the same format as standard momentum. In 
this case, the momentum term is updated using the gradient at the predicted new location, 


$$ \mathbf{m}_{t+1} = \beta \mathbf{m}_t - \eta_t \nabla \mathcal{L}(\boldsymbol{\theta}_t + \beta \mathbf{m}_t) \tag{8.37} $$
$$ \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t + \mathbf{m}_{t+1} \tag{8.38} $$


This explains why the Nesterov accelerated gradient method is sometimes called Nesterov momentum.
It also shows how this method can be faster than standard momentum: the momentum vector
is already roughly pointing in the right direction, so measuring the gradient at the new location,
$\boldsymbol{\theta}_t + \beta \mathbf{m}_t$, rather than the current location, $\boldsymbol{\theta}_t$, can be more accurate.

The Nesterov accelerated gradient method is provably faster than steepest descent for convex
functions when $\beta$ and $\eta_t$ are chosen appropriately. It is called “accelerated” because of this improved
convergence rate, which is optimal for gradient-based methods using only first-order information
when the objective function is convex and has Lipschitz-continuous gradients. In practice, however,
using Nesterov momentum can be slower than steepest descent, and can even be unstable if $\beta$ or $\eta_t$ are
misspecified.
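
The momentum form of the Nesterov update (Equations 8.37 and 8.38) translates almost line for line into code. This is a minimal sketch, assuming NumPy, a constant step size, and a zero initial momentum; the function name `nesterov_momentum`, the hyperparameter values, and the diagonal quadratic test problem are our own illustrative choices.

```python
import numpy as np

def nesterov_momentum(grad, theta0, eta=0.05, beta=0.9, num_steps=300):
    """Nesterov momentum with a constant step size eta."""
    theta = np.asarray(theta0, dtype=float)
    m = np.zeros_like(theta)
    for _ in range(num_steps):
        lookahead = theta + beta * m           # predicted new location
        m = beta * m - eta * grad(lookahead)   # Equation (8.37)
        theta = theta + m                      # Equation (8.38)
    return theta

# Test problem: L(theta) = 0.5 theta^T A theta, minimized at the origin.
A = np.diag([1.0, 10.0])
theta_final = nesterov_momentum(lambda th: A @ th, theta0=[5.0, 5.0])
print("final iterate:", theta_final)
```

The only difference from standard momentum is the point at which the gradient is evaluated: replacing `grad(lookahead)` with `grad(theta)` gives a heavy-ball style update instead.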


8.3 Second-order methods 


Optimization algorithms that only use the gradient are called first-order methods. They have the 
advantage that the gradient is cheap to compute and to store, but they do not model the curvature 
of the space, and hence they can be slow to converge, as we have seen in Figure 8.12. Second-order 
optimization methods incorporate curvature in various ways (e.g., via the Hessian), which may yield 
faster convergence. We discuss some of these methods below. 


8.3.1 Newton’s method 


The classic second-order method is Newton’s method. This consists of updates of the form 


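$$ \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta_t \mathbf{H}_t^{-1} \mathbf{g}_t \tag{8.39} $$


where


$$ \mathbf{H}_t \triangleq \nabla^2 \mathcal{L}(\boldsymbol{\theta})|_{\boldsymbol{\theta}_t} = \nabla^2 \mathcal{L}(\boldsymbol{\theta}_t) = \mathbf{H}(\boldsymbol{\theta}_t) \tag{8.40} $$


is assumed to be positive-definite to ensure the update is well-defined. The pseudo-code for Newton’s
method is given in Algorithm 8.1. The intuition for why this is faster than gradient descent is that the
matrix inverse $\mathbf{H}^{-1}$ “undoes” any skew in the local curvature, converting a topology like Figure 8.12a
to one like Figure 8.12b.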


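Algorithm 8.1: Newton’s method for minimizing a function


Initialize $\boldsymbol{\theta}_0$
for $t = 0, 1, 2, \ldots$ until convergence do
    Evaluate $\mathbf{g}_t = \nabla \mathcal{L}(\boldsymbol{\theta}_t)$
    Evaluate $\mathbf{H}_t = \nabla^2 \mathcal{L}(\boldsymbol{\theta}_t)$
    Solve $\mathbf{H}_t \mathbf{d}_t = -\mathbf{g}_t$ for $\mathbf{d}_t$
    Use line search to find stepsize $\eta_t$ along $\mathbf{d}_t$
    $\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t + \eta_t \mathbf{d}_t$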


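This algorithm can be derived as follows. Consider making a second-order Taylor series approximation
of $\mathcal{L}(\boldsymbol{\theta})$ around $\boldsymbol{\theta}_t$:


$$ \mathcal{L}_{\text{quad}}(\boldsymbol{\theta}) = \mathcal{L}(\boldsymbol{\theta}_t) + \mathbf{g}_t^\top (\boldsymbol{\theta} - \boldsymbol{\theta}_t) + \frac{1}{2} (\boldsymbol{\theta} - \boldsymbol{\theta}_t)^\top \mathbf{H}_t (\boldsymbol{\theta} - \boldsymbol{\theta}_t) \tag{8.41} $$


The minimum of $\mathcal{L}_{\text{quad}}$ is at


$$ \boldsymbol{\theta} = \boldsymbol{\theta}_t - \mathbf{H}_t^{-1} \mathbf{g}_t \tag{8.42} $$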
Figure 8.14: Illustration of Newton’s method for minimizing a 1d function. (a) The solid curve is the
function $\mathcal{L}(x)$. The dotted line $\mathcal{L}_{\text{quad}}(\theta)$ is its second order approximation at $\theta_t$. The Newton step $d_t$ is what
must be added to $\theta_t$ to get to the minimum of $\mathcal{L}_{\text{quad}}(\theta)$. Adapted from Figure 13.4 of [Van06]. Generated
by newtonsMethodMinQuad.ipynb. (b) Illustration of Newton’s method applied to a nonconvex function.
We fit a quadratic function around the current point $\theta_t$ and move to its stationary point, $\theta_{t+1} = \theta_t + d_t$.
Unfortunately, this takes us near a local maximum of $f$, not minimum. This means we need to be careful
about the extent of our quadratic approximation. Adapted from Figure 13.11 of [Van06]. Generated by
newtonsMethodNonConvex.ipynb.


So if the quadratic approximation is a good one, we should pick $\mathbf{d}_t = -\mathbf{H}_t^{-1} \mathbf{g}_t$ as our descent direction.
See Figure 8.14(a) for an illustration. Note that, in a “pure” Newton method, we use $\eta_t = 1$ as our
stepsize. However, we can also use linesearch to find the best stepsize; this tends to be more robust
as using $\eta_t = 1$ may not always converge globally.
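
One way to turn Algorithm 8.1 into code is sketched below, assuming NumPy. The function name `newton_minimize`, its defaults, and the test objective are our own; the line search is a simple backtracking rule that tries the pure Newton step first, which is one of several reasonable choices rather than a prescribed one. The Hessian is assumed to be positive definite at every iterate.

```python
import numpy as np

def newton_minimize(loss, grad, hess, theta0, max_iters=50, tol=1e-8, beta=0.5, c=1e-4):
    """Newton's method with backtracking, assuming a positive definite Hessian."""
    theta = np.asarray(theta0, dtype=float)
    for _ in range(max_iters):
        g = grad(theta)
        if np.linalg.norm(g) < tol:
            break
        H = hess(theta)
        d = np.linalg.solve(H, -g)  # solve H d = -g instead of forming H^{-1}
        eta = 1.0                   # try the pure Newton step first
        while loss(theta + eta * d) > loss(theta) + c * eta * (d @ g) and eta > 1e-12:
            eta *= beta
        theta = theta + eta * d
    return theta

# Smooth, strictly convex test objective with its minimum at the origin:
# L(theta) = sum_i (exp(theta_i) - theta_i) + 0.5 * ||theta||^2.
loss = lambda th: np.sum(np.exp(th) - th) + 0.5 * th @ th
grad = lambda th: np.exp(th) - 1.0 + th
hess = lambda th: np.diag(np.exp(th) + 1.0)

theta_hat = newton_minimize(loss, grad, hess, theta0=[2.0, -3.0])
print("estimated minimizer:", theta_hat)
```

Solving the linear system with `np.linalg.solve` avoids computing the inverse Hessian explicitly, which mirrors the "Solve" step of the algorithm.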

If we apply this method to linear regression, we get to the optimum in one step, since (as we show 
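in Section 11.2.2.1) we have $\mathbf{H} = \mathbf{X}^\top \mathbf{X}$ and $\mathbf{g} = \mathbf{X}^\top \mathbf{X} \mathbf{w} - \mathbf{X}^\top \mathbf{y}$, so the Newton update becomes


$$ \mathbf{w}_1 = \mathbf{w}_0 - \mathbf{H}^{-1} \mathbf{g} = \mathbf{w}_0 - (\mathbf{X}^\top \mathbf{X})^{-1} (\mathbf{X}^\top \mathbf{X} \mathbf{w}_0 - \mathbf{X}^\top \mathbf{y}) = \mathbf{w}_0 - \mathbf{w}_0 + (\mathbf{X}^\top \mathbf{X})^{-1} \mathbf{X}^\top \mathbf{y} \tag{8.43} $$


which is the OLS estimate. However, when we apply this method to logistic regression, it may take
multiple iterations to converge to the global optimum, as we discuss in Section 10.2.6.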
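
The one-step property for linear regression is easy to confirm numerically. The sketch below assumes NumPy and uses synthetic data of our own making (the sizes, seed, and noise level are arbitrary); it takes a single pure Newton step from a random starting point and compares the result with the least squares solution returned by `np.linalg.lstsq`.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))                 # synthetic design matrix
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true + 0.1 * rng.normal(size=50)   # noisy targets

w0 = rng.normal(size=3)                      # arbitrary starting point
H = X.T @ X                                  # Hessian of the least squares objective
g = X.T @ X @ w0 - X.T @ y                   # gradient at w0
w1 = w0 - np.linalg.solve(H, g)              # one pure Newton step (step size 1)

w_ols, *_ = np.linalg.lstsq(X, y, rcond=None)
print("Newton step:", w1)
print("OLS:        ", w_ols)
```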

8.3.2 BFGS and other quasi-Newton methods 


Quasi-Newton methods, sometimes called variable metric methods, iteratively build up an 
approximation to the Hessian using information gleaned from the gradient vector at each step. The 
most common method is called BFGS (named after its simultaneous inventors, Broyden, Fletcher, 
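Goldfarb and Shanno), which updates the approximation to the Hessian $\mathbf{B}_t \approx \mathbf{H}_t$ as follows:


$$ \mathbf{B}_{t+1} = \mathbf{B}_t + \frac{\mathbf{y}_t \mathbf{y}_t^\top}{\mathbf{y}_t^\top \mathbf{s}_t} - \frac{(\mathbf{B}_t \mathbf{s}_t)(\mathbf{B}_t \mathbf{s}_t)^\top}{\mathbf{s}_t^\top \mathbf{B}_t \mathbf{s}_t} \tag{8.44} $$
$$ \mathbf{s}_t = \boldsymbol{\theta}_t - \boldsymbol{\theta}_{t-1} \tag{8.45} $$
$$ \mathbf{y}_t = \mathbf{g}_t - \mathbf{g}_{t-1} \tag{8.46} $$


This is a rank-two update to the matrix. If $\mathbf{B}_0$ is positive-definite, and the step size $\eta$ is chosen
via line search satisfying both the Armijo condition in Equation (8.27) and the following curvature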
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Figure 8.15: Illustration of the trust region approach. The dashed lines represents contours of the original 
nonconvex objective. The circles represent successive quadratic approximations. From Figure 4.2 of [Pas14]. 
Used with kind permission of Razvan Pascanu. 


condition 
VL(O, + ndi) > cond} VL(O;) (8.47) 


then B,,, will remain positive definite. The constant cz is chosen within (c,1) where c is the 
tunable parameter in Equation (8.27). The two step size conditions are together known as the Wolfe 
conditions. We typically start with a diagonal approximation, By = I. Thus BFGS can be thought 
of as a “diagonal plus low-rank” approximation to the Hessian. 

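Alternatively, BFGS can iteratively update an approximation to the inverse Hessian, $\mathbf{C}_t \approx \mathbf{H}_t^{-1}$,
as follows:

$$\mathbf{C}_{t+1} = \left(\mathbf{I} - \frac{\mathbf{s}_t\mathbf{y}_t^\top}{\mathbf{y}_t^\top\mathbf{s}_t}\right)\mathbf{C}_t\left(\mathbf{I} - \frac{\mathbf{y}_t\mathbf{s}_t^\top}{\mathbf{y}_t^\top\mathbf{s}_t}\right) + \frac{\mathbf{s}_t\mathbf{s}_t^\top}{\mathbf{y}_t^\top\mathbf{s}_t} \tag{8.48}$$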
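As a concrete illustration of the inverse-Hessian update in Equation (8.48), here is a minimal sketch, assuming NumPy. The function name `bfgs_inverse_update` and the toy quadratic are illustrative assumptions. For a quadratic objective with Hessian $\mathbf{A}$, gradient differences satisfy $\mathbf{y}_t = \mathbf{A}\mathbf{s}_t$, and the updated matrix should satisfy the secant condition $\mathbf{C}_{t+1}\mathbf{y}_t = \mathbf{s}_t$.

```python
import numpy as np

def bfgs_inverse_update(C, s, y):
    """BFGS update of the inverse Hessian approximation, Equation (8.48)."""
    rho = 1.0 / (y @ s)                        # requires y^T s > 0
    V = np.eye(len(s)) - rho * np.outer(s, y)  # I - s y^T / (y^T s)
    return V @ C @ V.T + rho * np.outer(s, s)

# Toy quadratic with Hessian A, so gradient differences satisfy y = A s.
A = np.array([[3.0, 1.0], [1.0, 2.0]])
s = np.array([0.5, -1.0])
y = A @ s
C1 = bfgs_inverse_update(np.eye(2), s, y)

print(np.allclose(C1 @ y, s))                  # secant condition
print(np.all(np.linalg.eigvalsh(C1) > 0))      # still positive definite
```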


Since storing the Hessian approximation still takes $O(D^2)$ space, for very large problems, one
can use limited memory BFGS, or L-BFGS, where we control the rank of the approximation by
only using the $M$ most recent $(\mathbf{s}_t, \mathbf{y}_t)$ pairs while ignoring older information. Rather than storing
$\mathbf{B}_t$ explicitly, we just store these vectors in memory, and then approximate $\mathbf{H}_t^{-1}\mathbf{g}_t$ by performing a
sequence of inner products with the stored $\mathbf{s}_t$ and $\mathbf{y}_t$ vectors. The storage requirements are therefore
$O(MD)$. Typically choosing $M$ to be between 5 and 20 suffices for good performance [NW06, p177].

Note that sklearn uses LBFGS as its default solver for logistic regression.¹


8.3.3 Trust region methods 


If the objective function is nonconvex, then the Hessian $\mathbf{H}_t$ may not be positive definite, so
$\mathbf{d}_t = -\mathbf{H}_t^{-1}\mathbf{g}_t$ may not be a descent direction. This is illustrated in 1d in Figure 8.14(b), which shows
that Newton’s method can end up in a local maximum rather than a local minimum.

In general, any time the quadratic approximation made by Newton’s method becomes invalid, we
are in trouble. However, there is usually a local region around the current iterate where we can safely
approximate the objective by a quadratic. Let us call this region $\mathcal{R}_t$, and let us call $M_t(\boldsymbol{\delta})$ the model
(or approximation) to the objective, where $\boldsymbol{\delta} = \boldsymbol{\theta} - \boldsymbol{\theta}_t$. Then at each step we can solve

$$\boldsymbol{\delta}^* = \operatorname*{argmin}_{\boldsymbol{\delta}\in\mathcal{R}_t} M_t(\boldsymbol{\delta}) \tag{8.49}$$

This is called trust-region optimization. (This can be seen as the “opposite” of line search, in the
sense that we pick a distance we want to travel, determined by $\mathcal{R}_t$, and then solve for the optimal
direction, rather than picking the direction and then solving for the optimal distance.)

1. See https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html.

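We usually assume that $M_t(\boldsymbol{\delta})$ is a quadratic approximation:

$$M_t(\boldsymbol{\delta}) = \mathcal{L}(\boldsymbol{\theta}_t) + \mathbf{g}_t^\top\boldsymbol{\delta} + \frac{1}{2}\boldsymbol{\delta}^\top\mathbf{H}_t\boldsymbol{\delta} \tag{8.50}$$

where $\mathbf{g}_t = \nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})|_{\boldsymbol{\theta}_t}$ is the gradient, and $\mathbf{H}_t = \nabla^2_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})|_{\boldsymbol{\theta}_t}$ is the Hessian. Furthermore, it is common
to assume that $\mathcal{R}_t$ is a ball of radius $r$, i.e., $\mathcal{R}_t = \{\boldsymbol{\delta} : \|\boldsymbol{\delta}\|_2 \le r\}$. Using this, we can convert the
constrained problem into an unconstrained one as follows:

$$\boldsymbol{\delta}^* = \operatorname*{argmin}_{\boldsymbol{\delta}} M(\boldsymbol{\delta}) + \lambda\|\boldsymbol{\delta}\|_2^2 = \operatorname*{argmin}_{\boldsymbol{\delta}} \mathbf{g}^\top\boldsymbol{\delta} + \frac{1}{2}\boldsymbol{\delta}^\top(\mathbf{H} + \lambda\mathbf{I})\boldsymbol{\delta} \tag{8.51}$$

for some Lagrange multiplier $\lambda > 0$ which depends on the radius $r$ (see Section 8.5.1 for a discussion
of Lagrange multipliers). We can solve this using

$$\boldsymbol{\delta} = -(\mathbf{H} + \lambda\mathbf{I})^{-1}\mathbf{g} \tag{8.52}$$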


This is called Tikhonov damping or Tikhonov regularization. See Figure 8.15 for an illustration.

Note that adding a sufficiently large $\lambda\mathbf{I}$ to $\mathbf{H}$ ensures the resulting matrix is always positive definite.
As $\lambda \to 0$, this trust region method reduces to Newton’s method, but for $\lambda$ large enough, it will make all
the negative eigenvalues positive (and all the 0 eigenvalues become equal to $\lambda$).
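The effect of the damping term in Equation (8.52) can be seen on a small indefinite Hessian. This is a minimal sketch, assuming NumPy; the matrix, gradient and damping value are made-up illustrative numbers, and `damped_newton_step` is a hypothetical helper name.

```python
import numpy as np

def damped_newton_step(H, g, lam):
    """Tikhonov-damped Newton step: delta = -(H + lam * I)^{-1} g."""
    return -np.linalg.solve(H + lam * np.eye(len(g)), g)

H = np.array([[2.0, 0.0], [0.0, -1.0]])    # indefinite: eigenvalues 2 and -1
g = np.array([1.0, 1.0])

newton = damped_newton_step(H, g, lam=0.0)  # plain Newton step
damped = damped_newton_step(H, g, lam=2.0)  # H + 2I has eigenvalues 4 and 1

# A step delta is a descent direction when g^T delta < 0.
print(g @ newton)   # positive here, so the Newton step is not a descent direction
print(g @ damped)   # negative here, so the damped step is a descent direction
```

In practice $\lambda$ would be adapted to the trust region radius rather than fixed by hand as in this sketch.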


8.4 Stochastic gradient descent 


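In this section, we consider stochastic optimization, where the goal is to minimize the average
value of a function:

$$\mathcal{L}(\boldsymbol{\theta}) = \mathbb{E}_{q(\mathbf{z})}\left[\mathcal{L}(\boldsymbol{\theta}, \mathbf{z})\right] \tag{8.53}$$

where $\mathbf{z}$ is a random input to the objective. This could be a “noise” term, coming from the environment,
or it could be a training example drawn randomly from the training set, as we explain below.

At each iteration, we assume we observe $\mathcal{L}_t(\boldsymbol{\theta}) = \mathcal{L}(\boldsymbol{\theta}, \mathbf{z}_t)$, where $\mathbf{z}_t \sim q$. We also assume a way to
compute an unbiased estimate of the gradient of $\mathcal{L}$. If the distribution $q(\mathbf{z})$ is independent of the
parameters we are optimizing, we can use $\mathbf{g}_t = \nabla_{\boldsymbol{\theta}}\mathcal{L}_t(\boldsymbol{\theta}_t)$. In this case, the resulting algorithm can
be written as follows:

$$\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta_t\nabla\mathcal{L}(\boldsymbol{\theta}_t, \mathbf{z}_t) = \boldsymbol{\theta}_t - \eta_t\mathbf{g}_t \tag{8.54}$$

This method is known as stochastic gradient descent or SGD. As long as the gradient estimate
is unbiased, this method will converge to a stationary point, provided we decay the step size $\eta_t$
at a certain rate, as we discuss in Section 8.4.3.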
8.4.1 Application to finite sum problems


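SGD is very widely used in machine learning. To see why, recall from Section 4.3 that many model
fitting procedures are based on empirical risk minimization, which involves minimizing the following
loss:

$$\mathcal{L}(\boldsymbol{\theta}_t) = \frac{1}{N}\sum_{n=1}^{N}\ell(\mathbf{y}_n, f(\mathbf{x}_n; \boldsymbol{\theta}_t)) = \frac{1}{N}\sum_{n=1}^{N}\mathcal{L}_n(\boldsymbol{\theta}_t) \tag{8.55}$$

This is called a finite sum problem. The gradient of this objective has the form

$$\mathbf{g}_t = \frac{1}{N}\sum_{n=1}^{N}\nabla_{\boldsymbol{\theta}}\mathcal{L}_n(\boldsymbol{\theta}_t) = \frac{1}{N}\sum_{n=1}^{N}\nabla_{\boldsymbol{\theta}}\ell(\mathbf{y}_n, f(\mathbf{x}_n; \boldsymbol{\theta}_t)) \tag{8.56}$$

This requires summing over all $N$ training examples, and thus can be slow if $N$ is large. Fortunately
we can approximate this by sampling a minibatch of $B \ll N$ samples to get

$$\mathbf{g}_t \approx \frac{1}{|\mathcal{B}_t|}\sum_{n\in\mathcal{B}_t}\nabla_{\boldsymbol{\theta}}\mathcal{L}_n(\boldsymbol{\theta}_t) = \frac{1}{|\mathcal{B}_t|}\sum_{n\in\mathcal{B}_t}\nabla_{\boldsymbol{\theta}}\ell(\mathbf{y}_n, f(\mathbf{x}_n; \boldsymbol{\theta}_t)) \tag{8.57}$$

where $\mathcal{B}_t$ is a set of randomly chosen examples to use at iteration $t$.² This is an unbiased approximation
to the empirical average in Equation (8.56). Hence we can safely use this with SGD.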

Although the theoretical rate of convergence of SGD is slower than batch GD (in particular, SGD 
has a sublinear convergence rate), in practice SGD is often faster, since the per-step time is much 
lower [BB08; BB11]. To see why SGD can make faster progress than full batch GD, suppose we have 
a dataset consisting of a single example duplicated K times. Batch training will be (at least) K times 
slower than SGD, since it will waste time computing the gradient for the repeated examples. Even if 
there are no duplicates, batch training can be wasteful, since early on in training the parameters are 
not well estimated, so it is not worth carefully evaluating the gradient. 


8.4.2 Example: SGD for fitting linear regression 


In this section, we show how to use SGD to fit a linear regression model. Recall from Section 4.2.7
that the objective has the form

$$\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{2N}\sum_{n=1}^{N}(\mathbf{x}_n^\top\boldsymbol{\theta} - y_n)^2 = \frac{1}{2N}\|\mathbf{X}\boldsymbol{\theta} - \mathbf{y}\|_2^2 \tag{8.58}$$

The gradient is

$$\mathbf{g}_t = \frac{1}{N}\sum_{n=1}^{N}(\boldsymbol{\theta}_t^\top\mathbf{x}_n - y_n)\mathbf{x}_n \tag{8.59}$$

2. In practice we usually sample $\mathcal{B}_t$ without replacement. However, once we reach the end of the dataset (i.e., after a
single training epoch), we can perform a random shuffling of the examples, to ensure that each minibatch on the next
epoch is different from the last. This version of SGD is analyzed in [HS19].

Figure 8.16: Illustration of the LMS algorithm. Left: we start from $\boldsymbol{\theta} = (-0.5, 2)$ and slowly converge to
the least squares solution of $\hat{\boldsymbol{\theta}} = (1.45, 0.93)$ (red cross). Right: plot of objective function over time. Note
that it does not decrease monotonically. Generated by lms_demo.ipynb.


Now consider using SGD with a minibatch size of $B = 1$. The update becomes

$$\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta_t(\boldsymbol{\theta}_t^\top\mathbf{x}_n - y_n)\mathbf{x}_n \tag{8.60}$$

where $n = n(t)$ is the index of the example chosen at iteration $t$. The overall algorithm is called the
least mean squares (LMS) algorithm, and is also known as the delta rule, or the Widrow-Hoff
rule.

Figure 8.16 shows the results of applying this algorithm to the data shown in Figure 11.2. We
start at $\boldsymbol{\theta} = (-0.5, 2)$ and converge (in the sense that $\|\boldsymbol{\theta}_t - \boldsymbol{\theta}_{t-1}\|_2^2$ drops below a threshold of $10^{-2}$)
in about 26 iterations. Note that SGD (and hence LMS) may require multiple passes through the
data to find the optimum.
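The LMS update in Equation (8.60) is short enough to sketch directly. The following uses only the Python standard library; the noise-free synthetic data, the constant step size and the function name `lms` are illustrative assumptions, and this is not the code behind Figure 8.16.

```python
import random

def lms(xs, ys, theta, lr=0.1, epochs=50, seed=0):
    """LMS / Widrow-Hoff rule: SGD on squared error with minibatch size 1."""
    rng = random.Random(seed)
    order = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(order)                      # new example order each epoch
        for n in order:
            x = xs[n]
            err = sum(t * xi for t, xi in zip(theta, x)) - ys[n]
            theta = [t - lr * err * xi for t, xi in zip(theta, x)]
    return theta

# Noise-free data y = 1.5 + 0.9 * x, with a constant feature for the offset.
xs = [[1.0, -1.0 + 0.1 * i] for i in range(21)]
ys = [1.5 + 0.9 * x[1] for x in xs]
theta_hat = lms(xs, ys, theta=[-0.5, 2.0])
print(theta_hat)
```

Because the data here are noise-free, a constant step size is enough for the iterates to settle at the least squares solution; with noisy data the step size would need to decay, as discussed in Section 8.4.3.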


8.4.3 Choosing the step size (learning rate) 


When using SGD, we need to be careful in how we choose the learning rate in order to achieve 
convergence. For example, in Figure 8.17 we plot the loss vs the learning rate when we apply SGD 
to a deep neural network classifier (see Chapter 13 for details). We see a U-shaped curve, where an 
overly small learning rate results in underfitting, and overly large learning rate results in instability 
of the model (c.f., Figure 8.11(b)); in both cases, we fail to converge to a local optimum. 

One heuristic for choosing a good learning rate, proposed in [Smi18], is to start with a small
learning rate and gradually increase it, evaluating performance using a small number of minibatches.
We then make a plot like the one in Figure 8.17, and pick the learning rate with the lowest loss. (In
practice, it is better to pick a rate that is slightly smaller than (i.e., to the left of) the one with the
lowest loss, to ensure stability.)

Rather than choosing a single constant learning rate, we can use a learning rate schedule, in
which we adjust the step size over time. Theoretically, a sufficient condition for SGD to achieve
convergence is that the learning rate schedule satisfies the Robbins-Monro conditions:

$$\eta_t \to 0, \quad \frac{\sum_{t=1}^{\infty}\eta_t^2}{\sum_{t=1}^{\infty}\eta_t} \to 0 \tag{8.61}$$

Some common examples of learning rate schedules are listed below:

$$\eta_t = \eta_i \ \text{ if } t_i \le t \le t_{i+1} \quad \text{piecewise constant} \tag{8.62}$$
$$\eta_t = \eta_0 e^{-\lambda t} \quad \text{exponential decay} \tag{8.63}$$
$$\eta_t = \eta_0(\beta t + 1)^{-\alpha} \quad \text{polynomial decay} \tag{8.64}$$

In the piecewise constant schedule, $t_i$ are a set of time points at which we adjust the learning rate
to a specified value. For example, we may set $\eta_i = \eta_0\gamma^i$, which reduces the initial learning rate by a
factor of $\gamma$ for each threshold (or milestone) that we pass. Figure 8.18a illustrates this for $\eta_0 = 1$
and $\gamma = 0.9$. This is called step decay. Sometimes the threshold times are computed adaptively,
by estimating when the train or validation loss has plateaued; this is called reduce-on-plateau.
Exponential decay is typically too fast, as illustrated in Figure 8.18b. A common choice is polynomial
decay, with $\alpha = 0.5$ and $\beta = 1$, as illustrated in Figure 8.18c; this corresponds to a square-root
schedule, $\eta_t = \eta_0\frac{1}{\sqrt{t+1}}$.

Figure 8.17: Loss vs learning rate (horizontal axis). Training loss vs learning rate for a small MLP fit to
FashionMNIST using vanilla SGD. (Raw loss in blue, EWMA smoothed version in orange). Generated by
lrschedule_tf.ipynb.

Figure 8.18: Illustration of some common learning rate schedules. (a) Piecewise constant. (b) Exponential
decay. (c) Polynomial decay. Generated by learning_rate_plot.ipynb.

Figure 8.19: (a) Linear warm-up followed by cosine cool-down. (b) Cyclical learning rate schedule.
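The three schedules in Equations (8.62) to (8.64) can be written as small functions of the step index. This is a minimal sketch in plain Python; the function names, the milestone list and the default constants are illustrative assumptions.

```python
import math

def piecewise_constant(t, milestones, eta0=1.0, gamma=0.9):
    """Step decay: eta0 * gamma**i, where i is the number of milestones passed."""
    i = sum(t >= m for m in milestones)
    return eta0 * gamma ** i

def exponential_decay(t, eta0=1.0, lam=0.1):
    """eta_t = eta0 * exp(-lam * t)."""
    return eta0 * math.exp(-lam * t)

def polynomial_decay(t, eta0=1.0, alpha=0.5, beta=1.0):
    """eta_t = eta0 * (beta * t + 1)**(-alpha); the defaults give 1 / sqrt(t + 1)."""
    return eta0 * (beta * t + 1) ** (-alpha)

for t in (0, 10, 25):
    print(t, piecewise_constant(t, [10, 20]), exponential_decay(t), polynomial_decay(t))
```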

In the deep learning community, another common schedule is to quickly increase the learning rate
and then gradually decrease it again, as shown in Figure 8.19a. This is called learning rate warmup,
or the one-cycle learning rate schedule [Smi18]. The motivation for this is the following: initially
the parameters may be in a part of the loss landscape that is poorly conditioned, so a large step size
will “bounce around” too much (c.f., Figure 8.11(b)) and fail to make progress downhill. However,
with a small learning rate, the algorithm can discover flatter regions of space, where a larger step size
can be used. Once there, fast progress can be made. However, to ensure convergence to a point, we
must reduce the learning rate to 0. See [Got+19; Gil+21] for more details.
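A schedule of the kind shown in Figure 8.19a, with linear warm-up followed by cosine cool-down, can be sketched as follows. The function name, its arguments and the exact shape of the two phases are assumptions made for illustration; libraries differ in the details.

```python
import math

def warmup_cosine(t, total_steps, warmup_steps, peak_lr):
    """Linear warm-up to peak_lr, then cosine decay to 0 at total_steps."""
    if t < warmup_steps:
        return peak_lr * (t + 1) / warmup_steps
    progress = (t - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * min(1.0, progress)))

schedule = [warmup_cosine(t, total_steps=100, warmup_steps=10, peak_lr=1.0)
            for t in range(101)]
print(schedule[0], schedule[9], schedule[55], schedule[100])
```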

It is also possible to increase and decrease the learning rate multiple times, in a cyclical fashion.
This is called a cyclical learning rate [Smi18], and was popularized by the fast.ai course. See
Figure 8.19b for an illustration using triangular shapes. The motivation behind this approach is to
escape local minima. The minimum and maximum learning rates can be found based on the initial
“dry run” described above, and the half-cycle can be chosen based on how many restarts you want to
do with your training budget. A related approach, known as stochastic gradient descent with
warm restarts, was proposed in [LH17]; they proposed storing all the checkpoints visited after each
cool down, and using all of them as members of a model ensemble. (See Section 18.2 for a discussion
of ensemble learning.)

An alternative to using heuristics for estimating the learning rate is to use line search
(Section 8.2.2.2). This is tricky when using SGD, because the noisy gradients make the computation of
the Armijo condition difficult [CS20]. However, [Vas+19] show that it can be made to work if the
variance of the gradient noise goes to zero over time. This can happen if the model is sufficiently
flexible that it can perfectly interpolate the training set.


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 




8.4.4 Iterate averaging 
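The parameter estimates produced by SGD can be very unstable over time. To reduce the variance
of the estimate, we can compute the average using

$$\overline{\boldsymbol{\theta}}_t = \frac{1}{t}\sum_{i=1}^{t}\boldsymbol{\theta}_i = \frac{1}{t}\boldsymbol{\theta}_t + \frac{t-1}{t}\overline{\boldsymbol{\theta}}_{t-1} \tag{8.65}$$

where $\boldsymbol{\theta}_t$ are the usual SGD iterates. This is called iterate averaging or Polyak-Ruppert
averaging [Rup88].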
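The recursive form on the right of Equation (8.65) means the average can be maintained without storing past iterates. A minimal sketch in plain Python follows; the iterates are made-up numbers standing in for SGD output, and `running_average` is a hypothetical helper name.

```python
def running_average(theta_bar_prev, theta_t, t):
    """theta_bar_t = (1/t) * theta_t + ((t-1)/t) * theta_bar_{t-1}, for t = 1, 2, ..."""
    return [th / t + (t - 1) / t * tb for th, tb in zip(theta_t, theta_bar_prev)]

# Made-up iterates standing in for the output of SGD.
iterates = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0], [6.0, -2.0]]

theta_bar = [0.0, 0.0]
for t, theta in enumerate(iterates, start=1):
    theta_bar = running_average(theta_bar, theta, t)

batch_mean = [sum(col) / len(iterates) for col in zip(*iterates)]
print(theta_bar, batch_mean)
```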

In [PJ92], they prove that the estimate $\overline{\boldsymbol{\theta}}_t$ achieves the best possible asymptotic convergence rate
among SGD algorithms, matching that of variants using second-order information, such as Hessians.

This averaging can also have statistical benefits. For example, in [NR18], they prove that, in the
case of linear regression, this method is equivalent to $\ell_2$ regularization (i.e., ridge regression).

Rather than an exponential moving average of SGD iterates, Stochastic Weight Averaging 
(SWA) [Izm+18] uses an equal average in conjunction with a modified learning rate schedule. In 
contrast to standard Polyak-Ruppert averaging, which was motivated for faster convergence rates, 
SWA exploits the flatness in objectives used to train deep neural networks, to find solutions which 
provide better generalization. 


8.4.5 Variance reduction * 


In this section, we discuss various ways to reduce the variance in SGD. In some cases, this can 
improve the theoretical convergence rate from sublinear to linear (i.e., the same as full-batch gradient 
descent) [SLRB17; JZ13; DBLJ14]. These methods reduce the variance of the gradients, rather than 
the parameters themselves and are designed to work for finite sum problems. 


8.4.5.1 SVRG 


The basic idea of stochastic variance reduced gradient (SVRG) [JZ13] is to use a control
variate, in which we estimate a baseline value of the gradient based on the full batch, which we then
use to compare the stochastic gradients to.

More precisely, every so often (e.g., once per epoch), we compute the full gradient at a “snapshot”
of the model parameters $\tilde{\boldsymbol{\theta}}$; the corresponding “exact” gradient is therefore $\nabla\mathcal{L}(\tilde{\boldsymbol{\theta}})$. At step $t$, we
compute the usual stochastic gradient at the current parameters, $\nabla\mathcal{L}_t(\boldsymbol{\theta}_t)$, but also at the snapshot
parameters, $\nabla\mathcal{L}_t(\tilde{\boldsymbol{\theta}})$, which we use as a baseline. We can then use the following improved gradient
estimate

$$\mathbf{g}_t = \nabla\mathcal{L}_t(\boldsymbol{\theta}_t) - \nabla\mathcal{L}_t(\tilde{\boldsymbol{\theta}}) + \nabla\mathcal{L}(\tilde{\boldsymbol{\theta}}) \tag{8.66}$$

to compute $\boldsymbol{\theta}_{t+1}$. This is unbiased because $\mathbb{E}\left[\nabla\mathcal{L}_t(\tilde{\boldsymbol{\theta}})\right] = \nabla\mathcal{L}(\tilde{\boldsymbol{\theta}})$. Furthermore, the update only
involves two gradient computations, since we can compute $\nabla\mathcal{L}(\tilde{\boldsymbol{\theta}})$ once per epoch. At the end of the
epoch, we update the snapshot parameters, $\tilde{\boldsymbol{\theta}}$, based on the most recent value of $\boldsymbol{\theta}_t$, or a running
average of the iterates, and update the expected baseline. (We can compute snapshots less often, but
then the baseline will not be correlated with the objective and can hurt performance, as shown in
[DB18].)

Iterations of SVRG are computationally faster than those of full-batch GD, but SVRG can still
match the theoretical convergence rate of GD.
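The SVRG estimate in Equation (8.66) can be sketched on a tiny one-parameter least squares problem. This uses only the standard library; the data values, step size, epoch length and function names are illustrative assumptions, and the snapshot is simply reset to the latest iterate at the start of each epoch.

```python
import random

def grad_n(theta, x, y):
    """Gradient of the per-example loss 0.5 * (x * theta - y)**2."""
    return (x * theta - y) * x

def full_grad(theta, xs, ys):
    return sum(grad_n(theta, x, y) for x, y in zip(xs, ys)) / len(xs)

def svrg(xs, ys, theta, lr=0.1, epochs=30, seed=0):
    rng = random.Random(seed)
    N = len(xs)
    for _ in range(epochs):
        snapshot = theta
        g_snapshot = full_grad(snapshot, xs, ys)   # full gradient, once per epoch
        for _ in range(N):
            n = rng.randrange(N)
            # Stochastic gradient, corrected by the control variate.
            g = grad_n(theta, xs[n], ys[n]) - grad_n(snapshot, xs[n], ys[n]) + g_snapshot
            theta = theta - lr * g
    return theta

xs = [0.5, 1.0, 1.5, 2.0]
ys = [1.2, 1.9, 3.2, 3.9]
theta_svrg = svrg(xs, ys, theta=0.0)
theta_ols = sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)
print(theta_svrg, theta_ols)
```

Note that the step size stays constant here, yet the iterates still approach the least squares solution, because the variance of the corrected gradient shrinks as both the iterate and the snapshot approach the optimum.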


8.4.5.2 SAGA 


In this section, we describe the stochastic averaged gradient accelerated (SAGA) algorithm 
of [DBLJ14]. Unlike SVRG, it only requires one full batch gradient computation, at the start of 
the algorithm. However, it “pays” for this saving in time by using more memory. In particular, it 
must store N gradient vectors. This enables the method to maintain an approximation of the global 
gradient by removing the old local gradient from the overall sum and replacing it with the new local 
gradient. This is called an aggregated gradient method. 

More precisely, we first initialize by computing $\mathbf{g}_n^{\text{local}} = \nabla\mathcal{L}_n(\boldsymbol{\theta}_0)$ for all $n$, and the average,
$\mathbf{g}^{\text{avg}} = \frac{1}{N}\sum_{n=1}^{N}\mathbf{g}_n^{\text{local}}$. Then, at iteration $t$, we use the gradient estimate

$$\mathbf{g}_t = \nabla\mathcal{L}_n(\boldsymbol{\theta}_t) - \mathbf{g}_n^{\text{local}} + \mathbf{g}^{\text{avg}} \tag{8.67}$$

where $n \sim \text{Unif}\{1, \ldots, N\}$ is the example index sampled at iteration $t$. We then update $\mathbf{g}_n^{\text{local}} = \nabla\mathcal{L}_n(\boldsymbol{\theta}_t)$
and $\mathbf{g}^{\text{avg}}$ by replacing the old $\mathbf{g}_n^{\text{local}}$ by its new value.

This has an advantage over SVRG since it only has to do one full batch sweep at the start. (In
fact, the initial sweep is not necessary, since we can compute $\mathbf{g}^{\text{avg}}$ “lazily”, by only incorporating
gradients we have seen so far.) The downside is the large extra memory cost. However, if the features
(and hence gradients) are sparse, the memory cost can be reasonable. Indeed, the SAGA algorithm is
recommended for use in the sklearn logistic regression code when $N$ is large and $\mathbf{x}$ is sparse.³
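The bookkeeping behind Equation (8.67) is sketched below on a tiny one-parameter least squares problem, using only the standard library. The data, step size, number of steps and function names are illustrative assumptions; this is not the sklearn implementation.

```python
import random

def grad_n(theta, x, y):
    """Gradient of the per-example loss 0.5 * (x * theta - y)**2."""
    return (x * theta - y) * x

def saga(xs, ys, theta, lr=0.1, steps=2000, seed=0):
    rng = random.Random(seed)
    N = len(xs)
    g_local = [grad_n(theta, x, y) for x, y in zip(xs, ys)]  # one initial sweep
    g_avg = sum(g_local) / N
    for _ in range(steps):
        n = rng.randrange(N)
        g_new = grad_n(theta, xs[n], ys[n])
        g = g_new - g_local[n] + g_avg        # gradient estimate, Equation (8.67)
        g_avg += (g_new - g_local[n]) / N     # swap the old local gradient for the new
        g_local[n] = g_new                    # memory cost: one stored gradient per example
        theta = theta - lr * g
    return theta

xs = [0.5, 1.0, 1.5, 2.0]
ys = [1.2, 1.9, 3.2, 3.9]
theta_saga = saga(xs, ys, theta=0.0)
theta_ols = sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)
print(theta_saga, theta_ols)
```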


8.4.5.3 Application to deep learning 


Variance reduction methods are widely used for fitting ML models with convex objectives, such as 
linear models. However, there are various difficulties associated with using SVRG with conventional 
deep learning training practices. For example, the use of batch normalization (Section 14.2.4.1), data 
augmentation (Section 19.1) and dropout (Section 13.5.4) all break the assumptions of the method, 
since the loss will differ randomly in ways that depend not just on the parameters and the data index 
n. For more details, see e.g., [DB18; Arn+19]. 


8.4.6 Preconditioned SGD 


In this section, we consider preconditioned SGD, which involves the following update:

$$\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta_t\mathbf{M}_t^{-1}\mathbf{g}_t \tag{8.68}$$

where $\mathbf{M}_t$ is a preconditioning matrix, or simply the preconditioner, typically chosen to be
positive-definite. Unfortunately the noise in the gradient estimates makes it difficult to reliably
estimate the Hessian, which makes it difficult to use the methods from Section 8.3. In addition,
it is expensive to solve for the update direction with a full preconditioning matrix. Therefore
most practitioners use a diagonal preconditioner $\mathbf{M}_t$. Such preconditioners do not necessarily use
second-order information, but often result in speedups compared to vanilla SGD. See also [Roo+21]
for a probabilistic interpretation of these heuristics, and sgd_comparison.ipynb for an empirical
comparison on some simple datasets.

3. See https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression.


8.4.6.1 ADAGRAD 
The ADAGRAD (short for “adaptive gradient”) method of [DHS11] was originally designed for
optimizing convex objectives where many elements of the gradient vector are zero; these might
correspond to features that are rarely present in the input, such as rare words. The update has the
following form

$$\theta_{t+1,d} = \theta_{t,d} - \eta_t\frac{1}{\sqrt{s_{t,d} + \epsilon}}g_{t,d} \tag{8.69}$$

where $d = 1 : D$ indexes the dimensions of the parameter vector, and

$$s_{t,d} = \sum_{i=1}^{t}g_{i,d}^2 \tag{8.70}$$

is the sum of the squared gradients and $\epsilon > 0$ is a small term to avoid dividing by zero. Equivalently
we can write the update in vector form as follows:

$$\Delta\boldsymbol{\theta}_t = -\eta_t\frac{1}{\sqrt{\mathbf{s}_t + \epsilon}}\mathbf{g}_t \tag{8.71}$$

where the square root and division are performed elementwise. Viewed as preconditioned SGD, this is
equivalent to taking $\mathbf{M}_t = \text{diag}(\mathbf{s}_t + \epsilon)^{1/2}$. This is an example of an adaptive learning rate; the
overall stepsize $\eta_t$ still needs to be chosen, but the results are less sensitive to it compared to vanilla
GD. In particular, we usually fix $\eta_t = \eta_0$.
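The per-dimension scaling in Equations (8.69) and (8.70) can be sketched in a few lines of plain Python. The function name `adagrad_step`, the default constants and the example gradient are illustrative assumptions.

```python
import math

def adagrad_step(theta, s, g, lr=0.1, eps=1e-8):
    """One ADAGRAD step: accumulate squared gradients, then scale each dimension."""
    s = [si + gi * gi for si, gi in zip(s, g)]
    theta = [th - lr * gi / math.sqrt(si + eps) for th, si, gi in zip(theta, s, g)]
    return theta, s

# Two dimensions whose gradients differ in scale by a factor of 100.
theta, s = adagrad_step([0.0, 0.0], [0.0, 0.0], g=[10.0, 0.1])
print(theta)   # both coordinates move by roughly lr on the first step
```

On the first step the accumulated sum equals the squared gradient, so each coordinate moves by about the base learning rate regardless of the gradient's scale.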


8.4.6.2 RMSPROP and ADADELTA


A defining feature of ADAGRAD is that the term in the denominator gets larger over time, so the
effective learning rate drops. While this is necessary to ensure convergence, it might hurt performance
if the denominator gets large too fast.

An alternative is to use an exponentially weighted moving average (EWMA, Section 4.4.2.2) of the
past squared gradients, rather than their sum:

$$s_{t+1,d} = \beta s_{t,d} + (1 - \beta)g_{t,d}^2 \tag{8.72}$$

In practice we usually use $\beta \approx 0.9$, which puts more weight on recent examples. In this case,

$$\sqrt{s_{t,d}} \approx \text{RMS}(g_{1:t,d}) = \sqrt{\frac{1}{t}\sum_{\tau=1}^{t}g_{\tau,d}^2} \tag{8.73}$$

where RMS stands for “root mean squared”. Hence this method (which is based on the earlier
RPROP method of [RB93]) is known as RMSPROP [Hin14]. The overall update of RMSPROP
is

$$\Delta\boldsymbol{\theta}_t = -\eta_t\frac{1}{\sqrt{\mathbf{s}_t} + \epsilon}\mathbf{g}_t \tag{8.74}$$
The ADADELTA method was independently introduced in [Zei12], and is similar to RMSPROP.
However, in addition to accumulating an EWMA of the squared gradients in $\mathbf{s}_t$, it also keeps an EWMA of
the updates $\boldsymbol{\delta}_t$ to obtain an update of the form

$$\Delta\boldsymbol{\theta}_t = -\eta_t\frac{\sqrt{\boldsymbol{\delta}_{t-1} + \epsilon}}{\sqrt{\mathbf{s}_t + \epsilon}}\mathbf{g}_t \tag{8.75}$$

where

$$\boldsymbol{\delta}_t = \beta\boldsymbol{\delta}_{t-1} + (1 - \beta)(\Delta\boldsymbol{\theta}_t)^2 \tag{8.76}$$

and $\mathbf{s}_t$ is the same as in RMSPROP. This has the advantage that the “units” of the numerator and
denominator cancel, so we are just elementwise-multiplying the gradient by a scalar. This eliminates
the need to tune the learning rate $\eta_t$, which means one can simply set $\eta_t = 1$, although popular
implementations of ADADELTA still keep $\eta_t$ as a tunable hyperparameter. However, since these
adaptive learning rates need not decrease with time (unless we choose $\eta_t$ to explicitly do so), these
methods are not guaranteed to converge to a solution.


8.4.6.3 ADAM 


It is possible to combine RMSPROP with momentum. In particular, let us compute an EWMA of
the gradients (as in momentum) and squared gradients (as in RMSPROP):

$$\mathbf{m}_t = \beta_1\mathbf{m}_{t-1} + (1 - \beta_1)\mathbf{g}_t \tag{8.77}$$
$$\mathbf{s}_t = \beta_2\mathbf{s}_{t-1} + (1 - \beta_2)\mathbf{g}_t^2 \tag{8.78}$$

We then perform the following update:

$$\Delta\boldsymbol{\theta}_t = -\eta_t\frac{1}{\sqrt{\mathbf{s}_t} + \epsilon}\mathbf{m}_t \tag{8.79}$$

The resulting method is known as ADAM, which stands for “adaptive moment estimation” [KB15].

The standard values for the various constants are $\beta_1 = 0.9$, $\beta_2 = 0.999$ and $\epsilon = 10^{-6}$. (If we set
$\beta_1 = 0$ and use no bias-correction, we recover RMSPROP, which does not use momentum.) For the
overall learning rate, it is common to use a fixed value such as $\eta_t = 0.001$. Again, as the adaptive
learning rate may not decrease over time, convergence is not guaranteed (see Section 8.4.6.4).

If we initialize with $\mathbf{m}_0 = \mathbf{s}_0 = \mathbf{0}$, then initial estimates will be biased towards small values. The
authors therefore recommend using the bias-corrected moments, which increase the values early in
the optimization process. These estimates are given by

$$\hat{\mathbf{m}}_t = \mathbf{m}_t / (1 - \beta_1^t) \tag{8.80}$$
$$\hat{\mathbf{s}}_t = \mathbf{s}_t / (1 - \beta_2^t) \tag{8.81}$$

The advantage of bias-correction is shown in Figure 4.3.
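Putting Equations (8.77) to (8.81) together gives the following sketch of a single ADAM step in plain Python. The function name `adam_step` and the example values are illustrative; placing $\epsilon$ outside the square root is an assumption about the convention, and implementations differ on this detail.

```python
import math

def adam_step(theta, m, s, g, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-6):
    """One ADAM step with bias-corrected moments; t counts steps starting from 1."""
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    s = [beta2 * si + (1 - beta2) * gi * gi for si, gi in zip(s, g)]
    m_hat = [mi / (1 - beta1 ** t) for mi in m]   # undo the bias towards 0
    s_hat = [si / (1 - beta2 ** t) for si in s]
    theta = [th - lr * mh / (math.sqrt(sh) + eps)
             for th, mh, sh in zip(theta, m_hat, s_hat)]
    return theta, m, s

theta, m, s = adam_step([1.0, -2.0], [0.0, 0.0], [0.0, 0.0], g=[4.0, -0.5], t=1)
print(theta)   # each coordinate moves by about lr, in the direction opposite its gradient
```

With bias correction, the first step has magnitude close to the base learning rate in every coordinate; without it, the first step would be scaled by $(1 - \beta_1)/\sqrt{1 - \beta_2}$, which is about 3.2 for the standard constants.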




8.4.6.4 Issues with adaptive learning rates 


When using diagonal scaling methods, the overall learning rate is determined by $\eta_0\mathbf{M}_t^{-1}$, which
changes with time. Hence these methods are often called adaptive learning rate methods. However,
they still require setting the base learning rate $\eta_0$.

Since the EWMA methods are typically used in the stochastic setting where the gradient estimates
are noisy, their learning rate adaptation can result in non-convergence even on convex problems
[RKK18]. Various solutions to this problem have been proposed, including AMSGRAD [RKK18],
PADAM [CG18; Zho+18], and YOGI [Zah+18]. For example, the YOGI update modifies ADAM by
replacing

$$\mathbf{s}_t = \beta_2\mathbf{s}_{t-1} + (1 - \beta_2)\mathbf{g}_t^2 = \mathbf{s}_{t-1} + (1 - \beta_2)(\mathbf{g}_t^2 - \mathbf{s}_{t-1}) \tag{8.82}$$

with

$$\mathbf{s}_t = \mathbf{s}_{t-1} + (1 - \beta_2)\mathbf{g}_t^2 \odot \text{sgn}(\mathbf{g}_t^2 - \mathbf{s}_{t-1}) \tag{8.83}$$

However, more recent work [Zha+22] has shown that vanilla Adam can be made to always converge
provided the $\beta_1$ and $\beta_2$ parameters are tuned on a per-dataset basis. (In practice, it is common to
fix $\beta_1 = 0.9$ and just tune $\beta_2$.)
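The difference between Equations (8.82) and (8.83) is easiest to see on scalars. The following sketch, in plain Python with hypothetical function names and made-up values, compares how the two rules change the second-moment estimate when the squared gradient is far from the current estimate.

```python
def adam_second_moment(s_prev, g, beta2=0.999):
    """ADAM: the change in s is proportional to (g^2 - s_prev)."""
    return s_prev + (1 - beta2) * (g * g - s_prev)

def yogi_second_moment(s_prev, g, beta2=0.999):
    """YOGI: the change in s has magnitude (1 - beta2) * g^2 and the sign of (g^2 - s_prev)."""
    g2 = g * g
    sign = (g2 > s_prev) - (g2 < s_prev)
    return s_prev + (1 - beta2) * g2 * sign

# Large estimate, small gradient: YOGI shrinks the estimate more slowly than ADAM.
print(adam_second_moment(1.0, 0.1), yogi_second_moment(1.0, 0.1))
# Small estimate, large gradient: the two rules behave almost identically.
print(adam_second_moment(0.01, 1.0), yogi_second_moment(0.01, 1.0))
```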


8.4.6.5 Non-diagonal preconditioning matrices 


Although the methods we have discussed above can adapt the learning rate of each parameter, they 
do not solve the more fundamental problem of ill-conditioning due to correlation of the parameters, 
and hence do not always provide as much of a speed boost over vanilla SGD as one may hope. 


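One way to get faster convergence is to use the following preconditioning matrix, known as
full-matrix Adagrad [DHS11]:

$$\mathbf{M}_t = \left[(\mathbf{G}_t\mathbf{G}_t^\top)^{1/2} + \epsilon\mathbf{I}_D\right]^{-1} \tag{8.84}$$

where

$$\mathbf{G}_t = [\mathbf{g}_t, \ldots, \mathbf{g}_1] \tag{8.85}$$

Here $\mathbf{g}_i = \nabla_{\boldsymbol{\psi}}c(\boldsymbol{\psi}_i)$ is the $D$-dimensional gradient vector computed at step $i$. Unfortunately, $\mathbf{M}_t$ is a
$D \times D$ matrix, which is expensive to store and invert.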

The Shampoo algorithm [GKS18] makes a block diagonal approximation to M, one per layer 
of the model, and then exploits Kronecker product structure to efficiently invert it. (It is called 
“shampoo” because it uses a conditioner.) Recently, [Ani+20] scaled this method up to fit very large 
deep models in record time. 


8.5 Constrained optimization 


In this section, we consider the following constrained optimization problem:

$$\boldsymbol{\theta}^* = \operatorname*{arg\,min}_{\boldsymbol{\theta}\in\mathcal{C}}\mathcal{L}(\boldsymbol{\theta}) \tag{8.86}$$

where the feasible set, or constraint set, is

$$\mathcal{C} = \{\boldsymbol{\theta}\in\mathbb{R}^D : h_i(\boldsymbol{\theta}) = 0,\ i\in\mathcal{E},\ g_j(\boldsymbol{\theta}) \le 0,\ j\in\mathcal{I}\} \tag{8.87}$$

where $\mathcal{E}$ is the set of equality constraints, and $\mathcal{I}$ is the set of inequality constraints.

Figure 8.20: Illustration of some constrained optimization problems. Red contours are the level sets of the
objective function $\mathcal{L}(\boldsymbol{\theta})$. Optimal constrained solution is the black dot. (a) Blue line is the equality constraint
$h(\boldsymbol{\theta}) = 0$. (b) Blue lines denote the inequality constraints $|\theta_1| + |\theta_2| \le 1$. (Compare to Figure 11.8 (left).)

For example, suppose we have a quadratic objective, $\mathcal{L}(\boldsymbol{\theta}) = \theta_1^2 + \theta_2^2$, subject to a linear equality
constraint, $h(\boldsymbol{\theta}) = 1 - \theta_1 - \theta_2 = 0$. Figure 8.20(a) plots the level sets of $\mathcal{L}$, as well as the constraint
surface. What we are trying to do is find the point $\boldsymbol{\theta}^*$ that lives on the line, but which is closest to
the origin. It is clear from the geometry that the optimal solution is $\boldsymbol{\theta}^* = (0.5, 0.5)$, indicated by the
solid black dot.

In the following sections, we briefly describe some of the theory and algorithms underlying
constrained optimization. More details can be found in other books, such as [BV04; NW06; Ber15;
Ber16].


8.5.1 Lagrange multipliers 


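In this section, we discuss how to solve equality constrained optimization problems. We initially
assume that we have just one equality constraint, $h(\boldsymbol{\theta}) = 0$.

First note that for any point on the constraint surface, $\nabla h(\boldsymbol{\theta})$ will be orthogonal to the constraint
surface. To see why, consider another point nearby, $\boldsymbol{\theta} + \boldsymbol{\epsilon}$, that also lies on the surface. If we make a
first-order Taylor expansion around $\boldsymbol{\theta}$ we have

$$h(\boldsymbol{\theta} + \boldsymbol{\epsilon}) \approx h(\boldsymbol{\theta}) + \boldsymbol{\epsilon}^\top\nabla h(\boldsymbol{\theta}) \tag{8.88}$$

Since both $\boldsymbol{\theta}$ and $\boldsymbol{\theta} + \boldsymbol{\epsilon}$ are on the constraint surface, we must have $h(\boldsymbol{\theta}) = h(\boldsymbol{\theta} + \boldsymbol{\epsilon})$ and hence
$\boldsymbol{\epsilon}^\top\nabla h(\boldsymbol{\theta}) \approx 0$. Since $\boldsymbol{\epsilon}$ is parallel to the constraint surface, $\nabla h(\boldsymbol{\theta})$ must be perpendicular to it.

We seek a point $\boldsymbol{\theta}^*$ on the constraint surface such that $\mathcal{L}(\boldsymbol{\theta})$ is minimized. We just showed that it
must satisfy the condition that $\nabla h(\boldsymbol{\theta}^*)$ is orthogonal to the constraint surface. In addition, such a
point must have the property that $\nabla\mathcal{L}(\boldsymbol{\theta})$ is also orthogonal to the constraint surface, as otherwise
we could decrease $\mathcal{L}(\boldsymbol{\theta})$ by moving a short distance along the constraint surface. Since both $\nabla h(\boldsymbol{\theta})$
and $\nabla\mathcal{L}(\boldsymbol{\theta})$ are orthogonal to the constraint surface at $\boldsymbol{\theta}^*$, they must be parallel (or anti-parallel) to
each other. Hence there must exist a constant $\lambda^* \in \mathbb{R}$ such that

$$\nabla\mathcal{L}(\boldsymbol{\theta}^*) = \lambda^*\nabla h(\boldsymbol{\theta}^*) \tag{8.89}$$

(We cannot just equate the gradient vectors, since they may have different magnitudes.) The constant
$\lambda^*$ is called a Lagrange multiplier, and can be positive, negative, or zero. This latter case occurs
when $\nabla\mathcal{L}(\boldsymbol{\theta}^*) = \mathbf{0}$.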

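We can convert Equation (8.89) into an objective, known as the Lagrangian, whose stationary
points we seek:

$$L(\boldsymbol{\theta}, \lambda) \triangleq \mathcal{L}(\boldsymbol{\theta}) + \lambda h(\boldsymbol{\theta}) \tag{8.90}$$

At a stationary point of the Lagrangian, we have

$$\nabla_{\boldsymbol{\theta},\lambda}L(\boldsymbol{\theta}, \lambda) = \mathbf{0} \iff \nabla\mathcal{L}(\boldsymbol{\theta}) = -\lambda\nabla_{\boldsymbol{\theta}}h(\boldsymbol{\theta}),\ h(\boldsymbol{\theta}) = 0 \tag{8.91}$$

This is called a critical point, and satisfies the original constraint $h(\boldsymbol{\theta}) = 0$ and Equation (8.89)
with $\lambda^* = -\lambda$ (the sign of the multiplier is arbitrary).
If we have $m > 1$ constraints, we can form a new constraint function by addition, as follows:

$$L(\boldsymbol{\theta}, \boldsymbol{\lambda}) = \mathcal{L}(\boldsymbol{\theta}) + \sum_{j=1}^{m}\lambda_j h_j(\boldsymbol{\theta}) \tag{8.92}$$

We now have $D + m$ equations in $D + m$ unknowns and we can use standard unconstrained optimization
methods to find a stationary point. We give some examples below.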


8.5.1.1 Example: 2d Quadratic objective with one linear equality constraint 


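Consider minimizing $\mathcal{L}(\boldsymbol{\theta}) = \theta_1^2 + \theta_2^2$ subject to the constraint that $\theta_1 + \theta_2 = 1$.
(This is the problem illustrated in Figure 8.20(a).) The Lagrangian is

$$L(\theta_1, \theta_2, \lambda) = \theta_1^2 + \theta_2^2 + \lambda(\theta_1 + \theta_2 - 1) \tag{8.93}$$

We have the following conditions for a stationary point:

$$\frac{\partial}{\partial\theta_1}L(\theta_1, \theta_2, \lambda) = 2\theta_1 + \lambda = 0 \tag{8.94}$$
$$\frac{\partial}{\partial\theta_2}L(\theta_1, \theta_2, \lambda) = 2\theta_2 + \lambda = 0 \tag{8.95}$$
$$\frac{\partial}{\partial\lambda}L(\theta_1, \theta_2, \lambda) = \theta_1 + \theta_2 - 1 = 0 \tag{8.96}$$

From Equations 8.94 and 8.95 we find $2\theta_1 = -\lambda = 2\theta_2$, so $\theta_1 = \theta_2$. Also, from Equation (8.96), we
find $2\theta_1 = 1$. So $\boldsymbol{\theta}^* = (0.5, 0.5)$, as we claimed earlier. Furthermore, this is the global minimum since
the objective is convex and the constraint is affine.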
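Because the stationarity conditions in Equations (8.94) to (8.96) are linear in $(\theta_1, \theta_2, \lambda)$, they can also be solved as a single linear system. This is a minimal sketch, assuming NumPy; the variable names are illustrative.

```python
import numpy as np

# Rows encode: 2*theta1 + lam = 0, 2*theta2 + lam = 0, theta1 + theta2 = 1.
K = np.array([[2.0, 0.0, 1.0],
              [0.0, 2.0, 1.0],
              [1.0, 1.0, 0.0]])
rhs = np.array([0.0, 0.0, 1.0])

theta1, theta2, lam = np.linalg.solve(K, rhs)
print(theta1, theta2, lam)
```

The solve recovers $\boldsymbol{\theta}^* = (0.5, 0.5)$ together with the multiplier $\lambda = -1$, consistent with $2\theta_1 = -\lambda$.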
8.5.2 The KKT conditions


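In this section, we generalize the concept of Lagrange multipliers to additionally handle inequality
constraints.

First consider the case where we have a single inequality constraint $g(\boldsymbol{\theta}) \le 0$. To find the optimum,
one approach would be to consider an unconstrained problem where we add the penalty as an infinite
step function:

$$\check{\mathcal{L}}(\boldsymbol{\theta}) = \mathcal{L}(\boldsymbol{\theta}) + \infty\,\mathbb{I}(g(\boldsymbol{\theta}) > 0) \tag{8.97}$$

However, this is a discontinuous function that is hard to optimize.
Instead, we create a lower bound of the form $\mu g(\boldsymbol{\theta})$, where $\mu \ge 0$. This gives us the following
Lagrangian:

$$L(\boldsymbol{\theta}, \mu) = \mathcal{L}(\boldsymbol{\theta}) + \mu g(\boldsymbol{\theta}) \tag{8.98}$$

Note that the step function can be recovered using

$$\check{\mathcal{L}}(\boldsymbol{\theta}) = \max_{\mu\ge 0}L(\boldsymbol{\theta}, \mu) = \begin{cases}\infty & \text{if } g(\boldsymbol{\theta}) > 0, \\ \mathcal{L}(\boldsymbol{\theta}) & \text{otherwise}\end{cases} \tag{8.99}$$

Thus our optimization problem becomes

$$\min_{\boldsymbol{\theta}}\max_{\mu\ge 0}L(\boldsymbol{\theta}, \mu) \tag{8.100}$$

Now consider the general case where we have multiple inequality constraints, $\mathbf{g}(\boldsymbol{\theta}) \le \mathbf{0}$, and
multiple equality constraints, $\mathbf{h}(\boldsymbol{\theta}) = \mathbf{0}$. The generalized Lagrangian becomes

$$L(\boldsymbol{\theta}, \boldsymbol{\mu}, \boldsymbol{\lambda}) = \mathcal{L}(\boldsymbol{\theta}) + \sum_i\mu_i g_i(\boldsymbol{\theta}) + \sum_j\lambda_j h_j(\boldsymbol{\theta}) \tag{8.101}$$

(We are free to change $-\lambda_j h_j$ to $+\lambda_j h_j$ since the sign is arbitrary.) Our optimization problem
becomes

$$\min_{\boldsymbol{\theta}}\max_{\boldsymbol{\mu}\ge\mathbf{0},\boldsymbol{\lambda}}L(\boldsymbol{\theta}, \boldsymbol{\mu}, \boldsymbol{\lambda}) \tag{8.102}$$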


When $\mathcal{L}$ and $\mathbf{g}$ are convex, then all critical points of this problem must satisfy the following criteria
(under some conditions [BV04, Sec.5.2.3]):

• All constraints are satisfied (this is called feasibility):

$$\mathbf{g}(\boldsymbol{\theta}) \le \mathbf{0},\ \mathbf{h}(\boldsymbol{\theta}) = \mathbf{0} \tag{8.103}$$

• The solution is a stationary point:

$$\nabla\mathcal{L}(\boldsymbol{\theta}^*) + \sum_i\mu_i\nabla g_i(\boldsymbol{\theta}^*) + \sum_j\lambda_j\nabla h_j(\boldsymbol{\theta}^*) = \mathbf{0} \tag{8.104}$$

• The penalty for the inequality constraint points in the right direction (this is called dual
feasibility):

$$\boldsymbol{\mu} \ge \mathbf{0} \tag{8.105}$$

• The Lagrange multipliers pick up any slack in the inactive constraints, i.e., either $\mu_i = 0$ or
$g_i(\boldsymbol{\theta}^*) = 0$, so

$$\boldsymbol{\mu}\odot\mathbf{g} = \mathbf{0} \tag{8.106}$$

This is called complementary slackness.

Figure 8.21: (a) A convex polytope in 2d defined by the intersection of linear constraints. (b) Depiction of
the feasible set as well as the linear objective function. The red line is a level set of the objective, and the
arrow indicates the direction in which it is improving. We see that the optimal solution lies at a vertex of the
polytope.


To see why the last condition holds, consider (for simplicity) the case of a single inequality constraint, $g(\boldsymbol{\theta}) \le 0$. Either it is active, meaning $g(\boldsymbol{\theta}) = 0$, or it is inactive, meaning $g(\boldsymbol{\theta}) < 0$. In the active case, the solution lies on the constraint boundary, and $g(\boldsymbol{\theta}) = 0$ becomes an equality constraint; then we have $\nabla \mathcal{L} = \mu \nabla g$ for some constant $\mu \ne 0$, because of Equation (8.89). In the inactive case, the solution is not on the constraint boundary; we still have $\nabla \mathcal{L} = \mu \nabla g$, but now $\mu = 0$.

These are called the Karush-Kuhn-Tucker (KKT) conditions. If $\mathcal{L}$ is a convex function, and the constraints define a convex set, the KKT conditions are sufficient for (global) optimality, as well as necessary.
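
The four conditions can be checked numerically at a candidate point. The following is a minimal sketch on a toy problem that is an assumption made for illustration (it does not appear in the text): minimize $(\theta - 2)^2$ subject to $g(\theta) = \theta - 1 \le 0$. The helper name `kkt_report` is hypothetical.

```python
# Toy problem (assumed for illustration):
#   minimize loss(theta) = (theta - 2)^2  subject to  g(theta) = theta - 1 <= 0.
# The unconstrained minimum (theta = 2) is infeasible, so the constraint is active.

def grad_loss(theta):
    return 2.0 * (theta - 2.0)

def g(theta):
    return theta - 1.0

def grad_g(theta):
    return 1.0

def kkt_report(theta, mu, tol=1e-9):
    """Report which of the four KKT conditions hold at the pair (theta, mu)."""
    return {
        "feasibility": g(theta) <= tol,
        "stationarity": abs(grad_loss(theta) + mu * grad_g(theta)) <= tol,
        "dual_feasibility": mu >= -tol,
        "complementary_slackness": abs(mu * g(theta)) <= tol,
    }

# Candidate solution on the boundary: theta* = 1 with multiplier mu* = 2.
report = kkt_report(theta=1.0, mu=2.0)
print(report)

# The unconstrained minimizer with mu = 0 is stationary but violates feasibility.
bad_report = kkt_report(theta=2.0, mu=0.0)
print(bad_report)
```

Since this toy objective is convex and the feasible set is convex, passing all four checks at $(\theta, \mu) = (1, 2)$ is enough to certify global optimality.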


8.5.3 Linear programming 


Consider optimizing a linear function subject to linear constraints. When written in standard form, this can be represented as

$$\min_{\boldsymbol{\theta}} \boldsymbol{c}^{\mathsf{T}} \boldsymbol{\theta} \quad \text{s.t.} \quad \mathbf{A}\boldsymbol{\theta} \le \boldsymbol{b}, \ \boldsymbol{\theta} \ge \boldsymbol{0} \tag{8.107}$$

The feasible set defines a convex polytope, which is a convex set defined as the intersection of half spaces. See Figure 8.21(a) for a 2d example. Figure 8.21(b) shows a linear cost function that decreases as we move to the bottom right. We see that the lowest point that is in the feasible set is a vertex. In fact, it can be proved that the optimum point always occurs at a vertex of the polytope, assuming the solution is unique. If there are multiple solutions, the line will be parallel to a face. It is also possible that no point satisfies all the constraints, so that the feasible set is empty; in this case, the problem is said to be infeasible.


8.5.3.1 The simplex algorithm 
It can be shown that the optima of an LP occur at vertices of the polytope defining the feasible set (see Figure 8.21(b) for an example). The simplex algorithm solves LPs by moving from vertex to vertex, each time seeking the edge which most improves the objective.

In the worst-case scenario, the simplex algorithm can take time exponential in D, although in 
practice it is usually very efficient. There are also various polynomial-time algorithms, such as the 
interior point method, although these are often slower in practice. 
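
The vertex property can be illustrated with a brute-force sketch for 2d problems. This is not the simplex algorithm: it simply enumerates every vertex (intersection of two constraint boundaries) and keeps the best feasible one, which is only practical for tiny problems and assumes the feasible set is bounded. The function name and the example constraints are assumptions made for illustration.

```python
from itertools import combinations

def solve_lp_2d_by_vertices(c, constraints, tol=1e-9):
    """Minimize c[0]*x + c[1]*y subject to a1*x + a2*y <= b for each (a1, a2, b).

    Brute force: intersect every pair of constraint boundaries, keep the
    feasible intersection points (the polytope's vertices), return the best.
    Returns (None, inf) if no feasible vertex exists.
    """
    best_point, best_value = None, float("inf")
    for (a11, a12, b1), (a21, a22, b2) in combinations(constraints, 2):
        det = a11 * a22 - a12 * a21
        if abs(det) < tol:  # parallel boundaries never meet in a vertex
            continue
        # Cramer's rule for the 2x2 linear system.
        x = (b1 * a22 - a12 * b2) / det
        y = (a11 * b2 - a21 * b1) / det
        if all(a1 * x + a2 * y <= b + tol for a1, a2, b in constraints):
            value = c[0] * x + c[1] * y
            if value < best_value:
                best_point, best_value = (x, y), value
    return best_point, best_value

# Feasible set: x + y <= 4, x <= 3, y <= 2, x >= 0, y >= 0.
constraints = [
    (1.0, 1.0, 4.0),
    (1.0, 0.0, 3.0),
    (0.0, 1.0, 2.0),
    (-1.0, 0.0, 0.0),
    (0.0, -1.0, 0.0),
]
c = (-1.0, -2.0)  # minimize -x - 2y
vertex, value = solve_lp_2d_by_vertices(c, constraints)
print(vertex, value)
```

The number of vertices can grow exponentially with the dimension, which is why practical solvers move between neighboring vertices (simplex) or through the interior (interior point) rather than enumerating.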
8.5.3.2 Applications 


There are many applications of linear programming in science, engineering and business. It is also 
useful in some machine learning problems. For example, Section 11.6.1.1 shows how to use it to solve 
robust linear regression. It is also useful for state estimation in graphical models (see e.g., [SGJ11]). 
8.5.4 Quadratic programming 


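Consider minimizing a quadratic objective subject to linear equality and inequality constraints. This kind of problem is known as a quadratic program or QP, and can be written as follows:

$$\min_{\boldsymbol{\theta}} \frac{1}{2} \boldsymbol{\theta}^{\mathsf{T}} \mathbf{H} \boldsymbol{\theta} + \boldsymbol{c}^{\mathsf{T}} \boldsymbol{\theta} \quad \text{s.t.} \quad \mathbf{A}\boldsymbol{\theta} \le \boldsymbol{b}, \ \mathbf{C}\boldsymbol{\theta} = \boldsymbol{d} \tag{8.108}$$

If $\mathbf{H}$ is positive semidefinite, then this is a convex optimization problem.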


8.5.4.1 Example: 2d quadratic objective with linear inequality constraints 
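As a concrete example, suppose we want to minimize

$$\mathcal{L}(\boldsymbol{\theta}) = \left(\theta_1 - \tfrac{3}{2}\right)^2 + \left(\theta_2 - \tfrac{1}{8}\right)^2 = \frac{1}{2} \boldsymbol{\theta}^{\mathsf{T}} \mathbf{H} \boldsymbol{\theta} + \boldsymbol{c}^{\mathsf{T}} \boldsymbol{\theta} + \text{const} \tag{8.109}$$

where $\mathbf{H} = 2\mathbf{I}$ and $\boldsymbol{c} = -(3, 1/4)$, subject to

$$|\theta_1| + |\theta_2| \le 1 \tag{8.110}$$

See Figure 8.20(b) for an illustration.

We can rewrite the constraints as

$$\theta_1 + \theta_2 \le 1, \quad \theta_1 - \theta_2 \le 1, \quad -\theta_1 + \theta_2 \le 1, \quad -\theta_1 - \theta_2 \le 1 \tag{8.111}$$

which we can write more compactly as

$$\mathbf{A}\boldsymbol{\theta} \le \boldsymbol{b} \tag{8.112}$$

where $\boldsymbol{b} = \boldsymbol{1}$ and

$$\mathbf{A} = \begin{pmatrix} 1 & 1 \\ 1 & -1 \\ -1 & 1 \\ -1 & -1 \end{pmatrix} \tag{8.113}$$

This is now in the standard QP form.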

From the geometry of the problem, shown in Figure 8.20(b), we see that the constraints corresponding to the two left faces of the diamond are inactive (since we are trying to get as close to the center of the circle as possible, which is outside of, and to the right of, the constrained feasible region). Denoting $g_i(\boldsymbol{\theta})$ as the inequality constraint corresponding to row $i$ of $\mathbf{A}$ (written in the form $g_i(\boldsymbol{\theta}) \le 0$ used above), this means $g_3(\boldsymbol{\theta}^*) < 0$ and $g_4(\boldsymbol{\theta}^*) < 0$, and hence, by complementarity, $\mu_3^* = \mu_4^* = 0$. We can therefore remove these inactive constraints.

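From the KKT conditions we know that

$$\mathbf{H}\boldsymbol{\theta} + \boldsymbol{c} + \mathbf{A}^{\mathsf{T}} \boldsymbol{\mu} = \boldsymbol{0} \tag{8.114}$$

Using these for the actively constrained subproblem, we get

$$\begin{pmatrix} 2 & 0 & 1 & 1 \\ 0 & 2 & 1 & -1 \\ 1 & 1 & 0 & 0 \\ 1 & -1 & 0 & 0 \end{pmatrix} \begin{pmatrix} \theta_1 \\ \theta_2 \\ \mu_1 \\ \mu_2 \end{pmatrix} = \begin{pmatrix} 3 \\ 1/4 \\ 1 \\ 1 \end{pmatrix} \tag{8.115}$$

Hence the solution is

$$\boldsymbol{\theta}_* = (1, 0)^{\mathsf{T}}, \quad \boldsymbol{\mu}_* = (0.625, 0.375, 0, 0)^{\mathsf{T}} \tag{8.116}$$

Notice that the optimal value of $\boldsymbol{\theta}$ occurs at one of the vertices of the $\ell_1$ "ball" (the diamond shape).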
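
The linear system in Equation (8.115) can be solved numerically. The sketch below uses NumPy (an assumption; any linear solver would do) and takes the active set, rows 1 and 2 of $\mathbf{A}$, as given from the geometric argument above rather than discovering it automatically.

```python
import numpy as np

H = 2.0 * np.eye(2)
c = -np.array([3.0, 0.25])
A = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])
b = np.ones(4)

active = [0, 1]  # rows of A assumed active (the two right-hand faces)
A_act, b_act = A[active], b[active]

# Block system [[H, A_act^T], [A_act, 0]] [theta; mu_act] = [-c; b_act],
# i.e. stationarity (8.114) plus the active constraints held at equality.
n, m = H.shape[0], len(active)
kkt_matrix = np.block([[H, A_act.T], [A_act, np.zeros((m, m))]])
rhs = np.concatenate([-c, b_act])
solution = np.linalg.solve(kkt_matrix, rhs)

theta_star = solution[:n]
mu_star = np.zeros(len(b))
mu_star[active] = solution[n:]
print(theta_star, mu_star)

# Slack of every constraint, used to confirm feasibility and complementarity.
slack = A @ theta_star - b
print(slack)
```

If the guessed active set were wrong, the solve would return either an infeasible point or a negative multiplier, which is the signal that active-set methods use to revise their guess.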


8.5.4.2 Applications 


There are several applications of quadratic programming in ML. For example, in Section 11.4, we discuss the lasso method for sparse linear regression, which amounts to optimizing $\mathcal{L}(\boldsymbol{w}) = \|\mathbf{X}\boldsymbol{w} - \boldsymbol{y}\|_2^2 + \lambda \|\boldsymbol{w}\|_1$, which can be reformulated into a QP. And in Section 17.3, we show how to use QP for SVMs (support vector machines).


8.5.5 Mixed integer linear programming * 


Integer linear programming or ILP corresponds to minimizing a linear objective, subject to linear constraints, where the optimization variables are discrete integers instead of reals. In standard form, the problem is as follows:

$$\min_{\boldsymbol{\theta}} \boldsymbol{c}^{\mathsf{T}} \boldsymbol{\theta} \quad \text{s.t.} \quad \mathbf{A}\boldsymbol{\theta} \le \boldsymbol{b}, \ \boldsymbol{\theta} \ge \boldsymbol{0}, \ \boldsymbol{\theta} \in \mathbb{Z}^D \tag{8.117}$$

where $\mathbb{Z}$ is the set of integers. If some of the optimization variables are real-valued, it is called a mixed ILP, often called a MIP for short. (If all of the variables are real-valued, it becomes a standard LP.)

MIPs have a large number of applications, such as in vehicle routing, scheduling and packing. They are also useful for some ML applications, such as formally verifying the behavior of certain kinds of deep neural networks [And+18], and proving robustness properties of DNNs to adversarial (worst-case) perturbations [TXT19].


8.6 Proximal gradient method * 
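We are often interested in optimizing an objective of the form

$$\mathcal{L}(\boldsymbol{\theta}) = \mathcal{L}_s(\boldsymbol{\theta}) + \mathcal{L}_r(\boldsymbol{\theta}) \tag{8.118}$$

where $\mathcal{L}_s$ is differentiable (smooth), and $\mathcal{L}_r$ is convex but not necessarily differentiable (i.e., it may be non-smooth or "rough"). For example, $\mathcal{L}_s$ might be the negative log likelihood (NLL), and $\mathcal{L}_r$ might be an indicator function that is infinite if a constraint is violated (see Section 8.6.1), or $\mathcal{L}_r$ might be the $\ell_1$ norm of some parameters (see Section 8.6.2), or $\mathcal{L}_r$ might measure how far the parameters are from a set of allowed quantized values (see Section 8.6.3).

One way to tackle such problems is to use the proximal gradient method (see e.g., [PB+14; PSW15]). Roughly speaking, this takes a step of size $\eta$ along the negative gradient of the smooth part, and then projects the resulting parameter update into a space that respects $\mathcal{L}_r$. More precisely, the update is as follows

$$\boldsymbol{\theta}_{t+1} = \text{prox}_{\eta_t \mathcal{L}_r}\left(\boldsymbol{\theta}_t - \eta_t \nabla \mathcal{L}_s(\boldsymbol{\theta}_t)\right) \tag{8.119}$$

where $\text{prox}_{\eta \mathcal{L}_r}(\boldsymbol{\theta})$ is the proximal operator of $\mathcal{L}_r$ (scaled by $\eta$) evaluated at $\boldsymbol{\theta}$:

$$\text{prox}_{\eta \mathcal{L}_r}(\boldsymbol{\theta}) \triangleq \operatorname*{argmin}_{\boldsymbol{z}} \left( \mathcal{L}_r(\boldsymbol{z}) + \frac{1}{2\eta} \|\boldsymbol{z} - \boldsymbol{\theta}\|_2^2 \right) \tag{8.120}$$

(The factor of $\frac{1}{2}$ is an arbitrary convention.) We can rewrite the proximal operator as solving a constrained optimization problem, as follows:

$$\text{prox}_{\eta \mathcal{L}_r}(\boldsymbol{\theta}) = \operatorname*{argmin}_{\boldsymbol{z}} \mathcal{L}_r(\boldsymbol{z}) \quad \text{s.t.} \quad \|\boldsymbol{z} - \boldsymbol{\theta}\|_2 \le \rho \tag{8.121}$$

where the bound $\rho$ depends on the scaling factor $\eta$. Thus we see that the proximal projection minimizes the function while staying close to (i.e., proximal to) the current iterate. We give some examples below.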


8.6.1 Projected gradient descent 


Suppose we want to solve the problem

$$\operatorname*{argmin}_{\boldsymbol{\theta}} \mathcal{L}_s(\boldsymbol{\theta}) \quad \text{s.t.} \quad \boldsymbol{\theta} \in \mathcal{C} \tag{8.122}$$

where $\mathcal{C}$ is a convex set. For example, we may have the box constraints $\mathcal{C} = \{\boldsymbol{\theta} : \boldsymbol{l} \le \boldsymbol{\theta} \le \boldsymbol{u}\}$, where we specify lower and upper bounds on each element. These bounds can be infinite for certain elements if we don't want to constrain values along that dimension. For example, if we just want to ensure the parameters are non-negative, we set $l_d = 0$ and $u_d = \infty$ for each dimension $d$.

Figure 8.22: Illustration of projected gradient descent. $\boldsymbol{w}$ is the current parameter estimate, $\boldsymbol{w}'$ is the update after a gradient step, and $P_{\mathcal{C}}(\boldsymbol{w}')$ projects this onto the constraint set $\mathcal{C}$. From https://bit.ly/3eJ3BhZ. Used with kind permission of Martin Jaggi.

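We can convert the constrained optimization problem into an unconstrained one by adding a penalty term to the original objective:

$$\mathcal{L}(\boldsymbol{\theta}) = \mathcal{L}_s(\boldsymbol{\theta}) + \mathcal{L}_r(\boldsymbol{\theta}) \tag{8.123}$$

where $\mathcal{L}_r(\boldsymbol{\theta})$ is the indicator function for the convex set $\mathcal{C}$, i.e.,

$$\mathcal{L}_r(\boldsymbol{\theta}) = I_{\mathcal{C}}(\boldsymbol{\theta}) = \begin{cases} 0 & \text{if } \boldsymbol{\theta} \in \mathcal{C} \\ \infty & \text{if } \boldsymbol{\theta} \notin \mathcal{C} \end{cases} \tag{8.124}$$

We can use proximal gradient descent to solve Equation (8.123). The proximal operator for the indicator function is equivalent to projection onto the set $\mathcal{C}$:

$$\text{proj}_{\mathcal{C}}(\boldsymbol{\theta}) = \operatorname*{argmin}_{\boldsymbol{\theta}' \in \mathcal{C}} \|\boldsymbol{\theta}' - \boldsymbol{\theta}\|_2 \tag{8.125}$$

This method is known as projected gradient descent. See Figure 8.22 for an illustration.

For example, consider the box constraints $\mathcal{C} = \{\boldsymbol{\theta} : \boldsymbol{l} \le \boldsymbol{\theta} \le \boldsymbol{u}\}$. The projection operator in this case can be computed elementwise by simply thresholding at the boundaries:

$$\text{proj}_{\mathcal{C}}(\boldsymbol{\theta})_d = \begin{cases} l_d & \text{if } \theta_d \le l_d \\ \theta_d & \text{if } l_d \le \theta_d \le u_d \\ u_d & \text{if } \theta_d \ge u_d \end{cases} \tag{8.126}$$

For example, if we want to ensure all elements are non-negative, we can use

$$\text{proj}_{\mathcal{C}}(\boldsymbol{\theta}) = \boldsymbol{\theta}_+ = [\max(\theta_1, 0), \ldots, \max(\theta_D, 0)] \tag{8.127}$$

See Section 11.4.9.2 for an application of this method to sparse linear regression.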
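
Projected gradient descent with box constraints can be sketched in a few lines. The smooth loss used here, $\frac{1}{2}\|\boldsymbol{\theta} - \boldsymbol{a}\|_2^2$ with a made-up target $\boldsymbol{a}$, is an assumption chosen so that the constrained solution is known in closed form (it is the projection of $\boldsymbol{a}$ onto the box); the function names are hypothetical.

```python
def project_box(theta, lower, upper):
    """Elementwise projection onto {theta : lower <= theta <= upper} (Equation 8.126)."""
    return [min(max(t, lo), hi) for t, lo, hi in zip(theta, lower, upper)]

def projected_gradient_descent(grad, theta0, lower, upper, step_size=0.5, num_steps=100):
    """Alternate a gradient step on the smooth loss with a projection onto the box."""
    theta = project_box(theta0, lower, upper)
    for _ in range(num_steps):
        g = grad(theta)
        unconstrained = [t - step_size * gi for t, gi in zip(theta, g)]
        theta = project_box(unconstrained, lower, upper)
    return theta

# Smooth loss 0.5 * ||theta - target||^2, whose gradient is theta - target.
target = [2.0, -1.0, 0.5]

def grad(theta):
    return [t - a for t, a in zip(theta, target)]

theta_hat = projected_gradient_descent(grad, [0.0, 0.0, 0.0], [0.0] * 3, [1.0] * 3)
print(theta_hat)

# Non-negativity only: use an infinite upper bound (Equation 8.127).
inf = float("inf")
print(project_box([-1.0, 2.0], [0.0, 0.0], [inf, inf]))
```

The first two coordinates end up pinned to the box boundary, while the third converges to its unconstrained optimum because that value already lies inside the box.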
8.6.2 Proximal operator for $\ell_1$-norm regularizer


Consider a linear predictor of the form $f(\boldsymbol{x}; \boldsymbol{\theta}) = \sum_{d=1}^D \theta_d x_d$. If we have $\theta_d = 0$ for any dimension $d$, we ignore the corresponding feature $x_d$. This is a form of feature selection, which can be useful both as a way to reduce overfitting as well as a way to improve model interpretability. We can encourage weights to be zero (and not just small) by penalizing the $\ell_1$ norm,

$$\|\boldsymbol{\theta}\|_1 = \sum_{d=1}^D |\theta_d| \tag{8.128}$$

This is called a sparsity inducing regularizer.

To see why this induces sparsity, consider two possible parameter vectors, one which is sparse, $\boldsymbol{\theta} = (1, 0)$, and one which is non-sparse, $\boldsymbol{\theta}' = (1/\sqrt{2}, 1/\sqrt{2})$. Both have the same $\ell_2$ norm

$$\|(1, 0)\|_2^2 = \|(1/\sqrt{2}, 1/\sqrt{2})\|_2^2 = 1 \tag{8.129}$$

Hence $\ell_2$ regularization (Section 4.5.3) will not favor the sparse solution over the dense solution. However, when using $\ell_1$ regularization, the sparse solution is cheaper, since

$$\|(1, 0)\|_1 = 1 < \|(1/\sqrt{2}, 1/\sqrt{2})\|_1 = \sqrt{2} \tag{8.130}$$

See Section 11.4 for more details on sparse regression.

If we combine this regularizer with our smooth loss, we get

$$\mathcal{L}(\boldsymbol{\theta}) = \text{NLL}(\boldsymbol{\theta}) + \lambda \|\boldsymbol{\theta}\|_1 \tag{8.131}$$

We can optimize this objective using proximal gradient descent. The key question is how to compute the prox operator for the function $f(\boldsymbol{\theta}) = \|\boldsymbol{\theta}\|_1$. Since this function decomposes over dimensions $d$, the proximal projection can be computed componentwise. From Equation (8.120), with $\eta = 1$, we have

$$\text{prox}_{\lambda f}(\theta) = \operatorname*{argmin}_z |z| + \frac{1}{2\lambda}(z - \theta)^2 = \operatorname*{argmin}_z \lambda |z| + \frac{1}{2}(z - \theta)^2 \tag{8.132}$$

In Section 11.4.3, we show that the solution to this is given by

$$\text{prox}_{\lambda f}(\theta) = \begin{cases} \theta - \lambda & \text{if } \theta \ge \lambda \\ 0 & \text{if } |\theta| \le \lambda \\ \theta + \lambda & \text{if } \theta \le -\lambda \end{cases} \tag{8.133}$$

This is known as the soft thresholding operator, since values less than $\lambda$ in absolute value are set to 0 (thresholded), but in a continuous way. Note that soft thresholding can be written more compactly as

$$\text{SoftThreshold}(\theta, \lambda) = \text{sign}(\theta) \left(|\theta| - \lambda\right)_+ \tag{8.134}$$

where $\theta_+ = \max(\theta, 0)$ is the positive part of $\theta$. In the vector case, we perform this elementwise:

$$\text{SoftThreshold}(\boldsymbol{\theta}, \lambda) = \text{sign}(\boldsymbol{\theta}) \odot \left(|\boldsymbol{\theta}| - \lambda\right)_+ \tag{8.135}$$

See Section 11.4.9.3 for an application of this method to sparse linear regression.
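
The soft thresholding operator and the resulting proximal gradient loop can be sketched as follows. The smooth loss is assumed to be $\frac{1}{2}\|\boldsymbol{\theta} - \boldsymbol{a}\|_2^2$ for a made-up vector $\boldsymbol{a}$, so the exact minimizer is $\text{SoftThreshold}(\boldsymbol{a}, \lambda)$ and the iterates can be checked against it. Function names are hypothetical.

```python
def soft_threshold(theta, lam):
    """Elementwise sign(theta) * max(|theta| - lam, 0), as in Equation (8.135)."""
    out = []
    for t in theta:
        magnitude = max(abs(t) - lam, 0.0)
        out.append(magnitude if t >= 0 else -magnitude)
    return out

def proximal_gradient_l1(grad_smooth, theta0, lam, step_size, num_steps=200):
    """Minimize smooth(theta) + lam * ||theta||_1 by proximal gradient steps."""
    theta = list(theta0)
    for _ in range(num_steps):
        g = grad_smooth(theta)
        stepped = [t - step_size * gi for t, gi in zip(theta, g)]
        # The prox of step_size * lam * ||.||_1 thresholds at step_size * lam.
        theta = soft_threshold(stepped, step_size * lam)
    return theta

a = [3.0, -0.2, 0.5, -2.0]

def grad_smooth(theta):
    # Gradient of 0.5 * ||theta - a||^2.
    return [t - ai for t, ai in zip(theta, a)]

theta_hat = proximal_gradient_l1(grad_smooth, [0.0] * 4, lam=1.0, step_size=0.5)
print(theta_hat)
print(soft_threshold(a, 1.0))
```

Entries of $\boldsymbol{a}$ with magnitude below $\lambda$ are set exactly to zero rather than merely shrunk, which is the sparsity-inducing behavior described above.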
8.6.3 Proximal operator for quantization


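In some applications (e.g., when training deep neural networks to run on memory-limited edge devices, such as mobile phones) we want to ensure that the parameters are quantized. For example, in the extreme case where each parameter can only be -1 or +1, the state space becomes $\mathcal{C} = \{-1, +1\}^D$.

Let us define a regularizer that measures distance to the nearest quantized version of the parameter vector:

$$\mathcal{L}_r(\boldsymbol{\theta}) = \inf_{\boldsymbol{\theta}_0 \in \mathcal{C}} \|\boldsymbol{\theta} - \boldsymbol{\theta}_0\|_1 \tag{8.136}$$

(We could also use the $\ell_2$ norm.) In the case of $\mathcal{C} = \{-1, +1\}^D$, this becomes

$$\mathcal{L}_r(\boldsymbol{\theta}) = \sum_{d=1}^D \inf_{[\boldsymbol{\theta}_0]_d \in \{\pm 1\}} |\theta_d - [\boldsymbol{\theta}_0]_d| = \sum_{d=1}^D \min\{|\theta_d - 1|, |\theta_d + 1|\} = \|\boldsymbol{\theta} - \text{sign}(\boldsymbol{\theta})\|_1 \tag{8.137}$$

Let us define the corresponding quantization operator to be

$$q(\boldsymbol{\theta}) = \text{proj}_{\mathcal{C}}(\boldsymbol{\theta}) = \operatorname*{argmin} \mathcal{L}_r(\boldsymbol{\theta}) = \text{sign}(\boldsymbol{\theta}) \tag{8.138}$$

The core difficulty with quantized learning is that quantization is not a differentiable operation. A popular solution to this is to use the straight-through estimator, which uses the approximation $\frac{\partial \mathcal{L}}{\partial q(\boldsymbol{\theta})} \approx \frac{\partial \mathcal{L}}{\partial \boldsymbol{\theta}}$ (see e.g., [Yin+19]). The corresponding update can be done in two steps: first compute the gradient vector at the quantized version of the current parameters, and then update the unconstrained parameters using this approximate gradient:

$$\tilde{\boldsymbol{\theta}}_t = \text{proj}_{\mathcal{C}}(\boldsymbol{\theta}_t) = q(\boldsymbol{\theta}_t) \tag{8.139}$$

$$\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta_t \nabla \mathcal{L}_s(\tilde{\boldsymbol{\theta}}_t) \tag{8.140}$$

When applied to $\mathcal{C} = \{-1, +1\}^D$, this is known as the binary connect method [CBD15].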

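We can get better results using proximal gradient descent, in which we treat quantization as a regularizer, rather than a hard constraint; this is known as ProxQuant [BWL19]. The update becomes

$$\tilde{\boldsymbol{\theta}}_t = \text{prox}_{\lambda \mathcal{L}_r}\left(\boldsymbol{\theta}_t - \eta_t \nabla \mathcal{L}_s(\boldsymbol{\theta}_t)\right) \tag{8.141}$$

In the case that $\mathcal{C} = \{-1, +1\}^D$, one can show that the proximal operator is a generalization of the soft thresholding operator in Equation (8.135):

$$\text{prox}_{\lambda \mathcal{L}_r}(\boldsymbol{\theta}) = \text{SoftThreshold}(\boldsymbol{\theta}, \lambda, \text{sign}(\boldsymbol{\theta})) \tag{8.142}$$

$$= \text{sign}(\boldsymbol{\theta}) + \text{sign}(\boldsymbol{\theta} - \text{sign}(\boldsymbol{\theta})) \odot \left(|\boldsymbol{\theta} - \text{sign}(\boldsymbol{\theta})| - \lambda\right)_+ \tag{8.143}$$

This can be generalized to other forms of quantization; see [Yin+19] for details.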
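
The binary proximal operator in Equation (8.143) can be sketched elementwise as below. The tie-break $\text{sign}(0) = +1$ is an assumption made here so that the quantizer always returns a value in $\{-1, +1\}$; the function names are hypothetical.

```python
def sign(x):
    """Sign with the assumed tie-break sign(0) = +1, so the result is in {-1, +1}."""
    return 1.0 if x >= 0 else -1.0

def quantize(theta):
    """q(theta) = sign(theta): projection onto {-1, +1}^D (Equation 8.138)."""
    return [sign(t) for t in theta]

def prox_binary(theta, lam):
    """Equation (8.143): soft-threshold each element towards its nearest value in {-1, +1}."""
    out = []
    for t in theta:
        nearest = sign(t)
        residual = t - nearest
        shrunk = max(abs(residual) - lam, 0.0)
        out.append(nearest + sign(residual) * shrunk)
    return out

theta = [0.3, 0.9, -1.5, -0.05]
print(quantize(theta))
print(prox_binary(theta, lam=0.2))
```

With $\lambda = 0$ the operator returns its input unchanged, and for large $\lambda$ it coincides with the hard quantizer $q(\boldsymbol{\theta})$, so $\lambda$ interpolates between no quantization and full quantization.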


8.6.4 Incremental (online) proximal methods 


Many ML problems have an objective function which is a sum of losses, one per example. Such 
problems can be solved incrementally; this is a special case of online learning. It is possible to 
extend proximal methods to this setting. For a probabilistic perspective on such methods (in terms 
of Kalman filtering), see [AEM18; Aky+19]. 
8.7 Bound optimization *


In this section, we consider a class of algorithms known as bound optimization or MM algorithms. 
In the context of minimization, MM stands for majorize-minimize. In the context of maximization, 
MM stands for minorize-maximize. We will discuss a special case of MM, known as expectation 
maximization or EM, in Section 8.7.2. 


8.7.1 The general algorithm 


In this section, we give a brief outline of MM methods. (More details can be found in e.g., [HL04; Mai15; SBP17; Nad+19].) To be consistent with the literature, we assume our goal is to maximize some function $\ell(\boldsymbol{\theta})$, such as the log likelihood, wrt its parameters $\boldsymbol{\theta}$. The basic approach in MM algorithms is to construct a surrogate function $Q(\boldsymbol{\theta}, \boldsymbol{\theta}^t)$ which is a tight lower bound to $\ell(\boldsymbol{\theta})$ such that $Q(\boldsymbol{\theta}, \boldsymbol{\theta}^t) \le \ell(\boldsymbol{\theta})$ and $Q(\boldsymbol{\theta}^t, \boldsymbol{\theta}^t) = \ell(\boldsymbol{\theta}^t)$. If these conditions are met, we say that $Q$ minorizes $\ell$. We then perform the following update at each step:

$$\boldsymbol{\theta}^{t+1} = \operatorname*{argmax}_{\boldsymbol{\theta}} Q(\boldsymbol{\theta}, \boldsymbol{\theta}^t) \tag{8.144}$$

This guarantees us monotonic increases in the original objective:

$$\ell(\boldsymbol{\theta}^{t+1}) \ge Q(\boldsymbol{\theta}^{t+1}, \boldsymbol{\theta}^t) \ge Q(\boldsymbol{\theta}^t, \boldsymbol{\theta}^t) = \ell(\boldsymbol{\theta}^t) \tag{8.145}$$

where the first inequality follows since $Q(\boldsymbol{\theta}^{t+1}, \boldsymbol{\theta}')$ is a lower bound on $\ell(\boldsymbol{\theta}^{t+1})$ for any $\boldsymbol{\theta}'$; the second inequality follows from Equation (8.144); and the final equality follows from the tightness property. As a consequence of this result, if you do not observe monotonic increase of the objective, you must have an error in your math and/or code. This is a surprisingly powerful debugging tool.

This process is sketched in Figure 8.23. The dashed red curve is the original function (e.g., the log-likelihood of the observed data). The solid blue curve is the lower bound, evaluated at $\boldsymbol{\theta}^t$; this touches the objective function at $\boldsymbol{\theta}^t$. We then set $\boldsymbol{\theta}^{t+1}$ to the maximum of the lower bound (blue curve), and fit a new bound at that point (dotted green curve). The maximum of this new bound becomes $\boldsymbol{\theta}^{t+2}$, etc.

If $Q$ is a quadratic lower bound, the overall method is similar to Newton's method, which repeatedly fits and then optimizes a quadratic approximation, as shown in Figure 8.14(a). The difference is that optimizing $Q$ is guaranteed to lead to an improvement in the objective, even if it is not convex, whereas Newton's method may overshoot or lead to a decrease in the objective, as shown in Figure 8.24, since it is a quadratic approximation and not a bound.
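
As a small illustration of a quadratic lower bound, consider a 1d logistic regression log likelihood. This example is an assumption made for illustration, not one taken from the text. Its second derivative is bounded below by $-B$ with $B = \frac{1}{4}\sum_n x_n^2$, so $Q(\theta, \theta^t) = \ell(\theta^t) + \ell'(\theta^t)(\theta - \theta^t) - \frac{B}{2}(\theta - \theta^t)^2$ minorizes $\ell$, and maximizing it gives the update $\theta^{t+1} = \theta^t + \ell'(\theta^t)/B$. The data and function names are made up.

```python
import math

def log_likelihood(theta, xs, ys):
    """Bernoulli log likelihood of a 1d logistic model p(y=1|x) = sigmoid(theta * x)."""
    total = 0.0
    for x, y in zip(xs, ys):
        z = theta * x
        # log(1 + exp(z)), computed in a numerically stable way.
        log1pexp = max(z, 0.0) + math.log1p(math.exp(-abs(z)))
        total += y * z - log1pexp
    return total

def gradient(theta, xs, ys):
    total = 0.0
    for x, y in zip(xs, ys):
        sigmoid = 1.0 / (1.0 + math.exp(-theta * x))
        total += (y - sigmoid) * x
    return total

def mm_maximize(xs, ys, theta0=0.0, num_steps=50):
    # Curvature bound: -l''(theta) = sum_n x_n^2 s(1 - s) <= B, since s(1 - s) <= 1/4.
    B = 0.25 * sum(x * x for x in xs)
    theta, history = theta0, [log_likelihood(theta0, xs, ys)]
    for _ in range(num_steps):
        # Maximizer of the quadratic lower bound fitted at the current theta.
        theta = theta + gradient(theta, xs, ys) / B
        history.append(log_likelihood(theta, xs, ys))
    return theta, history

# Made-up, non-separable data so that the maximum is finite.
xs = [-2.0, -1.0, 0.5, 1.0, 2.0, 3.0]
ys = [0, 1, 0, 1, 1, 1]
theta_hat, history = mm_maximize(xs, ys)
print(theta_hat, history[0], history[-1])
```

Recording the objective at every step, as `history` does here, is what makes the monotonicity check usable as a debugging aid.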


8.7.2 The EM algorithm 


In this section, we discuss the expectation maximization (EM) algorithm [DLR77; MK97], which is a bound optimization algorithm designed to compute the MLE or MAP parameter estimate for probability models that have missing data and/or hidden variables. We let $\boldsymbol{y}_n$ be the visible data for example $n$, and $\boldsymbol{z}_n$ be the hidden data.

The basic idea behind EM is to alternate between estimating the hidden variables (or missing values) during the E step (expectation step), and then using the fully observed data to compute the MLE during the M step (maximization step). Of course, we need to iterate this process, since the expected values depend on the parameters, but the parameters depend on the expected values.

Figure 8.23: Illustration of a bound optimization algorithm. Adapted from Figure 9.14 of [Bis06]. Generated by emLogLikelihoodMax.ipynb.

Figure 8.24: The quadratic lower bound of an MM algorithm (solid) and the quadratic approximation of Newton's method (dashed) superimposed on an empirical density estimate (dotted). The starting point of both algorithms is the circle. The square denotes the outcome of one MM update. The diamond denotes the outcome of one Newton update. (a) Overshooting: Newton's method overshoots the global maximum. (b) Seeking the wrong root: Newton's method results in a reduction of the objective. From Figure 4 of [FT05]. Used with kind permission of Carlo Tomasi.

In Section 8.7.2.1, we show that EM is an MM algorithm, which implies that this iterative procedure will converge to a local maximum of the log likelihood. The speed of convergence depends on the amount of missing data, which affects the tightness of the bound [XJ96; MD97; SRG03; KKS20].


8.7.2.1 Lower bound 
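The goal of EM is to maximize the log likelihood of the observed data:

$$\ell(\boldsymbol{\theta}) = \sum_{n=1}^N \log p(\boldsymbol{y}_n | \boldsymbol{\theta}) = \sum_{n=1}^N \log \left[ \sum_{\boldsymbol{z}_n} p(\boldsymbol{y}_n, \boldsymbol{z}_n | \boldsymbol{\theta}) \right] \tag{8.146}$$

where $\boldsymbol{y}_n$ are the visible variables and $\boldsymbol{z}_n$ are the hidden variables. Unfortunately this is hard to optimize, since the log cannot be pushed inside the sum.

EM gets around this problem as follows. First, consider a set of arbitrary distributions $q_n(\boldsymbol{z}_n)$ over each hidden variable $\boldsymbol{z}_n$. The observed data log likelihood can be written as follows:

$$\ell(\boldsymbol{\theta}) = \sum_{n=1}^N \log \left[ \sum_{\boldsymbol{z}_n} q_n(\boldsymbol{z}_n) \frac{p(\boldsymbol{y}_n, \boldsymbol{z}_n | \boldsymbol{\theta})}{q_n(\boldsymbol{z}_n)} \right] \tag{8.147}$$

Using Jensen's inequality (Equation (6.34)), we can push the log (which is a concave function) inside the expectation to get the following lower bound on the log likelihood:

$$\ell(\boldsymbol{\theta}) \ge \sum_n \sum_{\boldsymbol{z}_n} q_n(\boldsymbol{z}_n) \log \frac{p(\boldsymbol{y}_n, \boldsymbol{z}_n | \boldsymbol{\theta})}{q_n(\boldsymbol{z}_n)} \tag{8.148}$$

$$= \sum_n \underbrace{\mathbb{E}_{q_n}\left[\log p(\boldsymbol{y}_n, \boldsymbol{z}_n | \boldsymbol{\theta})\right] + \mathbb{H}(q_n)}_{\text{Ł}(\boldsymbol{\theta}, q_n)} \tag{8.149}$$

$$= \sum_n \text{Ł}(\boldsymbol{\theta}, q_n) \triangleq \text{Ł}(\boldsymbol{\theta}, \{q_n\}) = \text{Ł}(\boldsymbol{\theta}, q_{1:N}) \tag{8.150}$$

where $\mathbb{H}(q)$ is the entropy of probability distribution $q$, and $\text{Ł}(\boldsymbol{\theta}, \{q_n\})$ is called the evidence lower bound or ELBO, since it is a lower bound on the log marginal likelihood, $\log p(\boldsymbol{y}_{1:N} | \boldsymbol{\theta})$, also called the evidence. Optimizing this bound is the basis of variational inference, which we discuss in Section 4.6.8.3.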


8.7.2.2 E step 
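We see that the lower bound is a sum of $N$ terms, each of which has the following form:

$$\text{Ł}(\boldsymbol{\theta}, q_n) = \sum_{\boldsymbol{z}_n} q_n(\boldsymbol{z}_n) \log \frac{p(\boldsymbol{y}_n, \boldsymbol{z}_n | \boldsymbol{\theta})}{q_n(\boldsymbol{z}_n)} \tag{8.151}$$

$$= \sum_{\boldsymbol{z}_n} q_n(\boldsymbol{z}_n) \log \frac{p(\boldsymbol{z}_n | \boldsymbol{y}_n, \boldsymbol{\theta}) \, p(\boldsymbol{y}_n | \boldsymbol{\theta})}{q_n(\boldsymbol{z}_n)} \tag{8.152}$$

$$= \sum_{\boldsymbol{z}_n} q_n(\boldsymbol{z}_n) \log \frac{p(\boldsymbol{z}_n | \boldsymbol{y}_n, \boldsymbol{\theta})}{q_n(\boldsymbol{z}_n)} + \sum_{\boldsymbol{z}_n} q_n(\boldsymbol{z}_n) \log p(\boldsymbol{y}_n | \boldsymbol{\theta}) \tag{8.153}$$

$$= -D_{\mathbb{KL}}\left(q_n(\boldsymbol{z}_n) \,\|\, p(\boldsymbol{z}_n | \boldsymbol{y}_n, \boldsymbol{\theta})\right) + \log p(\boldsymbol{y}_n | \boldsymbol{\theta}) \tag{8.154}$$

where $D_{\mathbb{KL}}(q \,\|\, p) \triangleq \sum_z q(z) \log \frac{q(z)}{p(z)}$ is the Kullback-Leibler divergence (or KL divergence for short) between probability distributions $q$ and $p$. We discuss this in more detail in Section 6.2, but the key property we need here is that $D_{\mathbb{KL}}(q \,\|\, p) \ge 0$ and $D_{\mathbb{KL}}(q \,\|\, p) = 0$ iff $q = p$. Hence we can maximize the lower bound $\text{Ł}(\boldsymbol{\theta}, \{q_n\})$ wrt $\{q_n\}$ by setting each one to $q_n^* = p(\boldsymbol{z}_n | \boldsymbol{y}_n, \boldsymbol{\theta})$. This is called the E step. This ensures the ELBO is a tight lower bound:

$$\text{Ł}(\boldsymbol{\theta}, \{q_n^*\}) = \sum_n \log p(\boldsymbol{y}_n | \boldsymbol{\theta}) = \ell(\boldsymbol{\theta}) \tag{8.155}$$

To see how this connects to bound optimization, let us define

$$Q(\boldsymbol{\theta}, \boldsymbol{\theta}^t) = \text{Ł}(\boldsymbol{\theta}, \{p(\boldsymbol{z}_n | \boldsymbol{y}_n; \boldsymbol{\theta}^t)\}) \tag{8.156}$$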


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 




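Then we have $Q(\boldsymbol{\theta}, \boldsymbol{\theta}^t) \le \ell(\boldsymbol{\theta})$ and $Q(\boldsymbol{\theta}^t, \boldsymbol{\theta}^t) = \ell(\boldsymbol{\theta}^t)$, as required.

However, if we cannot compute the posteriors $p(\boldsymbol{z}_n | \boldsymbol{y}_n; \boldsymbol{\theta}^t)$ exactly, we can still use an approximate distribution $q(\boldsymbol{z}_n | \boldsymbol{y}_n; \boldsymbol{\theta}^t)$; this will yield a non-tight lower-bound on the log-likelihood. This generalized version of EM is known as variational EM [NH98]. See the sequel to this book, [Mur23], for details.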
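
The relationship between the ELBO, the KL divergence and the log evidence in Equation (8.154) can be verified numerically for a single data point with a discrete hidden variable. The joint probabilities below are made-up numbers used only for illustration.

```python
import math

# Joint p(y, z | theta) for one fixed observation y and z in {0, 1, 2}
# (made-up numbers; they need not sum to one, since y is held fixed).
joint = [0.10, 0.25, 0.05]

log_evidence = math.log(sum(joint))           # log p(y | theta)
posterior = [p / sum(joint) for p in joint]   # p(z | y, theta)

def elbo(q, joint):
    """sum_z q(z) log [p(y, z) / q(z)], skipping terms with q(z) = 0."""
    return sum(qz * math.log(pz / qz) for qz, pz in zip(q, joint) if qz > 0)

def kl(q, p):
    """KL divergence sum_z q(z) log [q(z) / p(z)] between discrete distributions."""
    return sum(qz * math.log(qz / pz) for qz, pz in zip(q, p) if qz > 0)

q_uniform = [1 / 3, 1 / 3, 1 / 3]
gap = log_evidence - elbo(q_uniform, joint)
print(gap, kl(q_uniform, posterior))

# E step: choosing q to be the exact posterior closes the gap.
print(elbo(posterior, joint), log_evidence)
```

The gap between the log evidence and the ELBO equals the KL divergence from $q$ to the posterior, so any approximate $q$ (as in variational EM) gives a bound that is valid but not tight.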


8.7.2.3 M step 


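In the M step, we need to maximize $\text{Ł}(\boldsymbol{\theta}, \{q_n^t\})$ wrt $\boldsymbol{\theta}$, where the $q_n^t$ are the distributions computed in the E step at iteration $t$. Since the entropy terms $\mathbb{H}(q_n)$ are constant wrt $\boldsymbol{\theta}$, we can drop them in the M step. We are left with

$$\ell^t(\boldsymbol{\theta}) = \sum_n \mathbb{E}_{q_n^t(\boldsymbol{z}_n)}\left[\log p(\boldsymbol{y}_n, \boldsymbol{z}_n | \boldsymbol{\theta})\right] \tag{8.157}$$

This is called the expected complete data log likelihood. If the joint probability is in the exponential family (Section 3.4), we can rewrite this as

$$\ell^t(\boldsymbol{\theta}) = \sum_n \mathbb{E}\left[\mathcal{T}(\boldsymbol{y}_n, \boldsymbol{z}_n)^{\mathsf{T}} \boldsymbol{\theta} - A(\boldsymbol{\theta})\right] = \sum_n \left(\mathbb{E}\left[\mathcal{T}(\boldsymbol{y}_n, \boldsymbol{z}_n)\right]^{\mathsf{T}} \boldsymbol{\theta} - A(\boldsymbol{\theta})\right) \tag{8.158}$$

where $\mathbb{E}[\mathcal{T}(\boldsymbol{y}_n, \boldsymbol{z}_n)]$ are called the expected sufficient statistics.

In the M step, we maximize the expected complete data log likelihood to get

$$\boldsymbol{\theta}^{t+1} = \operatorname*{argmax}_{\boldsymbol{\theta}} \sum_n \mathbb{E}_{q_n^t}\left[\log p(\boldsymbol{y}_n, \boldsymbol{z}_n | \boldsymbol{\theta})\right] \tag{8.159}$$

In the case of the exponential family, the maximization can be solved in closed-form by matching the moments of the expected sufficient statistics.

We see from the above that the E step does not in fact need to return the full set of posterior distributions $\{q(\boldsymbol{z}_n)\}$, but can instead just return the sum of the expected sufficient statistics, $\sum_n \mathbb{E}_{q(\boldsymbol{z}_n)}[\mathcal{T}(\boldsymbol{y}_n, \boldsymbol{z}_n)]$. This will become clearer in the examples below.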


8.7.3 Example: EM for a GMM 


In this section, we show how to use the EM algorithm to compute MLE and MAP estimates of the 
parameters for a Gaussian mixture model (GMM). 


8.7.3.1 E step 


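The E step simply computes the responsibility of cluster $k$ for generating data point $n$, as estimated using the current parameter estimates $\boldsymbol{\theta}^{(t)}$:

$$r_{nk}^{(t)} = p^*(z_n = k | \boldsymbol{y}_n, \boldsymbol{\theta}^{(t)}) = \frac{\pi_k^{(t)} p(\boldsymbol{y}_n | \boldsymbol{\theta}_k^{(t)})}{\sum_{k'} \pi_{k'}^{(t)} p(\boldsymbol{y}_n | \boldsymbol{\theta}_{k'}^{(t)})} \tag{8.160}$$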
8.7.3.2 M step


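The M step maximizes the expected complete data log likelihood, given by

$$\ell^t(\boldsymbol{\theta}) = \mathbb{E}\left[\sum_n \log p(z_n | \boldsymbol{\pi}) + \sum_n \log p(\boldsymbol{y}_n | z_n, \boldsymbol{\theta})\right] \tag{8.161}$$

$$= \mathbb{E}\left[\sum_n \log \left(\prod_k \pi_k^{z_{nk}}\right) + \sum_n \log \left(\prod_k \mathcal{N}(\boldsymbol{y}_n | \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)^{z_{nk}}\right)\right] \tag{8.162}$$

$$= \sum_n \sum_k \mathbb{E}[z_{nk}] \log \pi_k + \sum_n \sum_k \mathbb{E}[z_{nk}] \log \mathcal{N}(\boldsymbol{y}_n | \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) \tag{8.163}$$

$$= \sum_n \sum_k r_{nk}^{(t)} \log(\pi_k) - \frac{1}{2} \sum_n \sum_k r_{nk}^{(t)} \left[\log |\boldsymbol{\Sigma}_k| + (\boldsymbol{y}_n - \boldsymbol{\mu}_k)^{\mathsf{T}} \boldsymbol{\Sigma}_k^{-1} (\boldsymbol{y}_n - \boldsymbol{\mu}_k)\right] + \text{const} \tag{8.164}$$

where $z_{nk} = \mathbb{I}(z_n = k)$ is a one-hot encoding of the categorical value $z_n$. This objective is just a weighted version of the standard problem of computing the MLEs of an MVN (see Section 4.2.6). One can show that the new parameter estimates are given by

$$\boldsymbol{\mu}_k^{(t+1)} = \frac{\sum_n r_{nk}^{(t)} \boldsymbol{y}_n}{r_k^{(t)}} \tag{8.165}$$

$$\boldsymbol{\Sigma}_k^{(t+1)} = \frac{\sum_n r_{nk}^{(t)} (\boldsymbol{y}_n - \boldsymbol{\mu}_k^{(t+1)})(\boldsymbol{y}_n - \boldsymbol{\mu}_k^{(t+1)})^{\mathsf{T}}}{r_k^{(t)}} = \frac{\sum_n r_{nk}^{(t)} \boldsymbol{y}_n \boldsymbol{y}_n^{\mathsf{T}}}{r_k^{(t)}} - \boldsymbol{\mu}_k^{(t+1)} (\boldsymbol{\mu}_k^{(t+1)})^{\mathsf{T}} \tag{8.166}$$

where $r_k^{(t)} \triangleq \sum_n r_{nk}^{(t)}$ is the weighted number of points assigned to cluster $k$. The mean of cluster $k$ is just the weighted average of all points assigned to cluster $k$, and the covariance is proportional to the weighted empirical scatter matrix.

The M step for the mixture weights is simply a weighted form of the usual MLE:

$$\pi_k^{(t+1)} = \frac{1}{N} \sum_n r_{nk}^{(t)} = \frac{r_k^{(t)}}{N} \tag{8.167}$$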


8.7.3.3 Example 


An example of the algorithm in action is shown in Figure 8.25 where we fit some 2d data with a 2 component GMM. The data set, from [Bis06], is derived from measurements of the Old Faithful geyser in Yellowstone National Park. In particular, we plot the time to next eruption in minutes versus the duration of the eruption in minutes. The data was standardized, by removing the mean and dividing by the standard deviation, before processing; this often helps convergence. We start with $\boldsymbol{\mu}_1 = (-1, 1)$, $\boldsymbol{\Sigma}_1 = \mathbf{I}$, $\boldsymbol{\mu}_2 = (1, -1)$, $\boldsymbol{\Sigma}_2 = \mathbf{I}$. We then show the cluster assignments, and corresponding mixture components, at various iterations.
For more details on applying GMMs for clustering, see Section 21.4.1.
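
The E and M steps above can be sketched for scalar data, where each covariance reduces to a variance. This is a minimal illustration on a small made-up data set (not the Old Faithful data), with hypothetical function names; it omits safeguards such as working in log space, which a practical implementation would need.

```python
import math

def normal_pdf(y, mean, var):
    return math.exp(-0.5 * (y - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)

def log_likelihood(ys, weights, means, variances):
    return sum(
        math.log(sum(w * normal_pdf(y, m, v) for w, m, v in zip(weights, means, variances)))
        for y in ys
    )

def em_step(ys, weights, means, variances):
    K, N = len(weights), len(ys)
    # E step: responsibilities resp[n][k], as in Equation (8.160).
    resp = []
    for y in ys:
        joint = [w * normal_pdf(y, m, v) for w, m, v in zip(weights, means, variances)]
        total = sum(joint)
        resp.append([j / total for j in joint])
    # M step: Equations (8.165)-(8.167), specialised to scalar data.
    new_w, new_m, new_v = [], [], []
    for k in range(K):
        r_k = sum(r[k] for r in resp)
        mean_k = sum(r[k] * y for r, y in zip(resp, ys)) / r_k
        var_k = sum(r[k] * (y - mean_k) ** 2 for r, y in zip(resp, ys)) / r_k
        new_w.append(r_k / N)
        new_m.append(mean_k)
        new_v.append(var_k)
    return new_w, new_m, new_v

# Made-up 1d data with two well separated groups.
ys = [-2.3, -1.9, -2.1, -1.6, -2.4, -1.8, 1.7, 2.2, 1.9, 2.5, 2.0, 1.6]
params = ([0.5, 0.5], [-1.0, 1.0], [1.0, 1.0])
history = [log_likelihood(ys, *params)]
for _ in range(20):
    params = em_step(ys, *params)
    history.append(log_likelihood(ys, *params))
weights, means, variances = params
print(weights, means, variances)
```

Because EM is an MM algorithm, the values stored in `history` should never decrease, which gives a quick sanity check on the implementation.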
Figure 8.25: Illustration of the EM for a GMM applied to the Old Faithful data. The degree of redness indicates the degree to which the point belongs to the red cluster, and similarly for blue; thus purple points have a roughly 50/50 split in their responsibilities to the two clusters. Adapted from [Bis06] Figure 9.8. Generated by mix_gauss_demo_faithful.ipynb.


8.7.3.4 MAP estimation 


Computing the MLE of a GMM often suffers from numerical problems and overfitting. To see why, suppose for simplicity that $\boldsymbol{\Sigma}_k = \sigma_k^2 \mathbf{I}$ for all $k$. It is possible to get an infinite likelihood by assigning one of the centers, say $\boldsymbol{\mu}_k$, to a single data point, say $\boldsymbol{y}_n$, since then the likelihood of that data point is given by

$$\mathcal{N}(\boldsymbol{y}_n | \boldsymbol{\mu}_k = \boldsymbol{y}_n, \sigma_k^2 \mathbf{I}) = \frac{1}{\sqrt{2\pi\sigma_k^2}} e^0 \tag{8.168}$$

Hence we can drive this term to infinity by letting $\sigma_k \to 0$, as shown in Figure 8.26(a). We call this the "collapsing variance problem".

An easy solution to this is to perform MAP estimation. Fortunately, we can still use EM to find this MAP estimate. Our goal is now to maximize the expected complete data log-likelihood plus the log prior:

$$\left[\sum_n \sum_k r_{nk}^{(t)} \log \pi_k + \sum_n \sum_k r_{nk}^{(t)} \log p(\boldsymbol{y}_n | \boldsymbol{\theta}_k)\right] + \log p(\boldsymbol{\pi}) + \sum_k \log p(\boldsymbol{\theta}_k) \tag{8.169}$$

Note that the E step remains unchanged, but the M step needs to be modified, as we now explain.

For the prior on the mixture weights, it is natural to use a Dirichlet prior (Section 4.6.3.2), $\boldsymbol{\pi} \sim \text{Dir}(\boldsymbol{\alpha})$, since this is conjugate to the categorical distribution. The MAP estimate is given by

$$\tilde{\pi}_k^{(t+1)} = \frac{r_k^{(t)} + \alpha_k - 1}{N + \sum_k \alpha_k - K} \tag{8.170}$$

If we use a uniform prior, $\alpha_k = 1$, this reduces to the MLE.
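
The effect of the Dirichlet prior in Equation (8.170) can be seen with a short sketch. The soft counts below are hypothetical values chosen so that one cluster receives no responsibility at all; the function name is also an assumption.

```python
def map_mixture_weights(r_k, alpha):
    """Equation (8.170): MAP estimate of the mixture weights under a Dir(alpha) prior.

    r_k[k] is the weighted number of points assigned to cluster k (sums to N).
    """
    N, K = sum(r_k), len(r_k)
    denom = N + sum(alpha) - K
    return [(r + a - 1.0) / denom for r, a in zip(r_k, alpha)]

r_k = [9.0, 1.0, 0.0]  # hypothetical soft counts, N = 10
mle = map_mixture_weights(r_k, [1.0, 1.0, 1.0])       # uniform prior recovers the MLE
smoothed = map_mixture_weights(r_k, [2.0, 2.0, 2.0])  # alpha_k = 2 acts like one extra count per cluster
print(mle)
print(smoothed)
```

With $\alpha_k > 1$, no mixture weight can be driven exactly to zero, which keeps an empty cluster from being permanently switched off.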
Figure 8.26: (a) Illustration of how singularities can arise in the likelihood function of GMMs. Here $K = 2$, but the first mixture component is a narrow spike (with $\sigma_1 \approx 0$) centered on a single data point $x_1$. Adapted from Figure 9.7 of [Bis06]. Generated by mix_gauss_singularity.ipynb. (b) Illustration of the benefit of MAP estimation vs ML estimation when fitting a Gaussian mixture model. We plot the fraction of times (out of 5 random trials) each method encounters numerical problems vs the dimensionality of the problem, for $N = 100$ samples. Solid red (upper curve): MLE. Dotted black (lower curve): MAP. Generated by mix_gauss_mle_vs_map.ipynb. (Vertical axes: $p(x)$ in panel (a); fraction of times EM for GMM fails in panel (b).)


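For the prior on the mixture components, let us consider a conjugate prior of the form

$$p(\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) = \text{NIW}(\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k | \breve{\boldsymbol{m}}, \breve{\kappa}, \breve{\nu}, \breve{\mathbf{S}}) \tag{8.171}$$

This is called the Normal-Inverse-Wishart distribution (see the sequel to this book, [Mur23], for details). Suppose we set the hyper-parameters for $\boldsymbol{\mu}$ to be $\breve{\kappa} = 0$, so that the $\boldsymbol{\mu}_k$ are unregularized; thus the prior will only influence our estimate of $\boldsymbol{\Sigma}_k$. In this case, the MAP estimates are given by

$$\tilde{\boldsymbol{\mu}}_k^{(t+1)} = \hat{\boldsymbol{\mu}}_k^{(t+1)} \tag{8.172}$$

$$\tilde{\boldsymbol{\Sigma}}_k^{(t+1)} = \frac{\breve{\mathbf{S}} + \hat{\boldsymbol{\Sigma}}_k^{(t+1)}}{\breve{\nu} + r_k^{(t)} + D + 2} \tag{8.173}$$

where $\hat{\boldsymbol{\mu}}_k$ is the MLE for $\boldsymbol{\mu}_k$ from Equation (8.165), and $\hat{\boldsymbol{\Sigma}}_k$ is the MLE for $\boldsymbol{\Sigma}_k$ from Equation (8.166).

Now we discuss how to set the prior covariance, $\breve{\mathbf{S}}$. One possibility (suggested in [FR07, p163]) is to use

$$\breve{\mathbf{S}} = \frac{1}{K^{2/D}} \text{diag}(s_1^2, \ldots, s_D^2) \tag{8.174}$$

where $s_d^2 = (1/N) \sum_{n=1}^N (x_{nd} - \bar{x}_d)^2$ is the pooled variance for dimension $d$. The parameter $\breve{\nu}$ controls how strongly we believe this prior. The weakest prior we can use, while still being proper, is to set $\breve{\nu} = D + 2$, so this is a common choice.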

We now illustrate the benefits of using MAP estimation instead of ML estimation in the context of GMMs. We apply EM to some synthetic data with $N = 100$ samples in $D$ dimensions, using either ML or MAP estimation. We count the trial as a "failure" if there are numerical issues involving singular matrices. For each dimensionality, we conduct 5 random trials. The results are illustrated in Figure 8.26(b). We see that as soon as $D$ becomes even moderately large, ML estimation crashes and burns, whereas MAP estimation with an appropriate prior rarely encounters numerical problems.
Figure 8.27: Left: $N = 200$ data points sampled from a mixture of 2 Gaussians in 1d, with $\pi_k = 0.5$, $\sigma_k = 5$, $\mu_1 = -10$ and $\mu_2 = 10$. Right: Likelihood surface $p(\mathcal{D} | \mu_1, \mu_2)$, with all other parameters set to their true values. We see the two symmetric modes, reflecting the unidentifiability of the parameters. Generated by gmm_lik_surface_plot.ipynb.


8.7.3.5 Nonconvexity of the NLL 


The log likelihood for a mixture model is given by

$$\ell(\boldsymbol{\theta}) = \sum_{n=1}^N \log \left\{ \sum_{z_n=1}^K p(\boldsymbol{y}_n, z_n | \boldsymbol{\theta}) \right\} \tag{8.175}$$

In general, this will have multiple modes, and hence there will not be a unique global optimum.

Figure 8.27 illustrates this for a mixture of 2 Gaussians in 1d. We see that there are two equally good global optima, corresponding to two different labelings of the clusters, one in which the left peak corresponds to $z = 1$, and one in which the left peak corresponds to $z = 2$. This is called the label switching problem; see Section 21.4.1.2 for more details.

The question of how many modes there are in the likelihood function is hard to answer. There are $K!$ possible labelings, but some of the peaks might get merged, depending on how far apart the $\boldsymbol{\mu}_k$ are. Nevertheless, there can be an exponential number of modes. Consequently, finding any global optimum is NP-hard [Alo+09; Dri+04]. We will therefore have to be satisfied with finding a local optimum. To find a good local optimum, we can use Kmeans++ (Section 21.3.4) to initialize EM.
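
The label switching symmetry can be checked directly: swapping the two means leaves the log likelihood unchanged. The sketch below follows the setting of Figure 8.27 (equal weights, shared $\sigma = 5$) but uses a handful of made-up data points rather than the 200 samples in the figure.

```python
import math

def gmm_log_likelihood(ys, mu1, mu2, sigma=5.0, weight=0.5):
    """Log likelihood of a two-component 1d GMM with shared sigma and equal weights."""
    def pdf(y, mu):
        return math.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))
    return sum(math.log(weight * pdf(y, mu1) + (1 - weight) * pdf(y, mu2)) for y in ys)

ys = [-12.0, -9.5, -10.5, -8.0, 9.0, 11.5, 10.0, 12.5]  # made-up data
ll_a = gmm_log_likelihood(ys, mu1=-10.0, mu2=10.0)
ll_b = gmm_log_likelihood(ys, mu1=10.0, mu2=-10.0)   # labels swapped
ll_mid = gmm_log_likelihood(ys, mu1=0.0, mu2=0.0)    # both components at the origin
print(ll_a, ll_b, ll_mid)
```

The two labelings give the same value, so the likelihood surface has at least two equally good modes, while the point between them is strictly worse on this data.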


8.8 Blackbox and derivative free optimization 


In some optimization problems, the objective function is a blackbox, meaning that its functional form is unknown. This means we cannot use gradient-based methods to optimize it. Instead, solving such problems requires blackbox optimization (BBO) methods, also called derivative free optimization (DFO).

In ML, this kind of problem often arises when performing model selection. For example, suppose we have some hyper-parameters, $\boldsymbol{\lambda} \in \boldsymbol{\Lambda}$, which control the type or complexity of a model. We often define the objective function $\mathcal{L}(\boldsymbol{\lambda})$ to be the loss on a validation set (see Section 4.5.4). Since the validation loss depends on the optimal model parameters, which are computed using a complex algorithm, this objective function is effectively a blackbox.⁴

A simple approach to such problems is to use grid search, where we evaluate each point on a grid laid over the parameter space, and pick the one with the lowest loss. Unfortunately, this does not scale to high dimensions, because of the curse of dimensionality. In addition, even in low dimensions this can be expensive if evaluating the blackbox objective is expensive (e.g., if it first requires training the model before computing the validation loss). Various solutions to this problem have been proposed. See the sequel to this book, [Mur23], for details.
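
Grid search itself is only a few lines. In the sketch below, `validation_loss` is a cheap stand-in (an assumption) for an expensive routine that would train a model and return its validation loss; the hyper-parameter names and grid values are also made up.

```python
from itertools import product

def grid_search(objective, grid):
    """Evaluate `objective` at every combination of values in `grid`; return the best."""
    names = list(grid)
    best_config, best_loss = None, float("inf")
    for values in product(*(grid[name] for name in names)):
        config = dict(zip(names, values))
        loss = objective(**config)
        if loss < best_loss:
            best_config, best_loss = config, loss
    return best_config, best_loss

# Stand-in for an expensive "train, then compute validation loss" routine.
def validation_loss(log10_l2, num_layers):
    return (log10_l2 + 2) ** 2 + 0.1 * (num_layers - 3) ** 2

grid = {"log10_l2": [-4, -3, -2, -1, 0], "num_layers": [1, 2, 3, 4]}
best_config, best_loss = grid_search(validation_loss, grid)
num_evaluations = len(grid["log10_l2"]) * len(grid["num_layers"])
print(best_config, best_loss, num_evaluations)
```

The number of evaluations is the product of the grid sizes, so it grows exponentially with the number of hyper-parameters, which is the scaling problem noted above.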


8.9 Exercises 


Exercise 8.1 [Subderivative of the hinge loss function *]
Let $f(x) = (1 - x)_+$ be the hinge loss function, where $(z)_+ = \max(0, z)$. What are $\partial f(0)$, $\partial f(1)$, and $\partial f(2)$?


Exercise 8.2 [EM for the Student distribution] 


Derive the EM equations for computing the MLE for a multivariate Student distribution. Consider the case 
where the dof parameter is known and unknown separately. Hint: write the Student distribution as a scale 
mixture of Gaussians. 


4. If the optimal parameters are computed using a gradient-based optimizer, we can “unroll” the gradient steps, to 
create a deep circuit that maps from the training data to the optimal parameters and hence to the validation loss. We 
can then optimize through the optimizer (see e.g., [Fra+17]). However, this technique can only be applied in limited 
settings. 
Linear Models


9 Linear Discriminant Analysis 


9.1 Introduction 
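In this chapter, we consider classification models of the following form:

$$p(y = c | \boldsymbol{x}, \boldsymbol{\theta}) = \frac{p(\boldsymbol{x} | y = c, \boldsymbol{\theta}) \, p(y = c | \boldsymbol{\theta})}{\sum_{c'} p(\boldsymbol{x} | y = c', \boldsymbol{\theta}) \, p(y = c' | \boldsymbol{\theta})} \tag{9.1}$$

The term $p(y = c | \boldsymbol{\theta})$ is the prior over class labels, and the term $p(\boldsymbol{x} | y = c, \boldsymbol{\theta})$ is called the class conditional density for class $c$.

The overall model is called a generative classifier, since it specifies a way to generate the features $\boldsymbol{x}$ for each class $c$, by sampling from $p(\boldsymbol{x} | y = c, \boldsymbol{\theta})$. By contrast, a discriminative classifier directly models the class posterior $p(y | \boldsymbol{x}, \boldsymbol{\theta})$. We discuss the pros and cons of these two approaches to classification in Section 9.4.

If we choose the class conditional densities in a special way, we will see that the resulting posterior over classes is a linear function of $\boldsymbol{x}$, i.e., $\log p(y = c | \boldsymbol{x}, \boldsymbol{\theta}) = \boldsymbol{w}^{\mathsf{T}} \boldsymbol{x} + \text{const}$, where $\boldsymbol{w}$ is derived from $\boldsymbol{\theta}$. Thus the overall method is called linear discriminant analysis or LDA.¹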


9.2 Gaussian discriminant analysis 


In this section, we consider a generative classifier where the class conditional densities are multivariate
Gaussians:

$$p(\mathbf{x}|y = c, \boldsymbol{\theta}) = \mathcal{N}(\mathbf{x}|\boldsymbol{\mu}_c, \boldsymbol{\Sigma}_c) \tag{9.2}$$

The corresponding class posterior therefore has the form

$$p(y = c|\mathbf{x}, \boldsymbol{\theta}) \propto \pi_c\, \mathcal{N}(\mathbf{x}|\boldsymbol{\mu}_c, \boldsymbol{\Sigma}_c) \tag{9.3}$$

where $\pi_c = p(y = c|\boldsymbol{\theta})$ is the prior probability of label $c$. (Note that we can ignore the normalization
constant in the denominator of the posterior, since it is independent of $c$.) We call this model
Gaussian discriminant analysis or GDA.


1. This term is rather confusing for two reasons. First, LDA is a generative, not discriminative, classifier. Second, 
LDA also stands for “latent Dirichlet allocation”, which is a popular unsupervised generative model for bags of words 
[BNJ03].

Figure 9.1: (a) Some 2d data from 3 different classes. (b) Fitting 2d Gaussians to each class. Generated by
discrim_analysis_dboundaries_plot2.ipynb.


Figure 9.2: Gaussian discriminant analysis fit to data in Figure 9.1. (a) Unconstrained covariances induce
quadratic decision boundaries. (b) Tied covariances induce linear decision boundaries. Generated by
discrim_analysis_dboundaries_plot2.ipynb.


9.2.1 Quadratic decision boundaries 


From Equation (9.3), we see that the log posterior over class labels is given by

$$\log p(y = c|\mathbf{x}, \boldsymbol{\theta}) = \log \pi_c - \frac{1}{2} \log |2\pi \boldsymbol{\Sigma}_c| - \frac{1}{2} (\mathbf{x} - \boldsymbol{\mu}_c)^\top \boldsymbol{\Sigma}_c^{-1} (\mathbf{x} - \boldsymbol{\mu}_c) + \text{const} \tag{9.4}$$

This is called the discriminant function. We see that the decision boundary between any two classes,
say $c$ and $c'$, will be a quadratic function of $\mathbf{x}$. Hence this is known as quadratic discriminant
analysis (QDA).
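
To make the discriminant function in Equation (9.4) concrete, the following is a minimal NumPy sketch that evaluates it for each class and normalizes the result. The function name `qda_log_posterior` and the toy parameters are illustrative assumptions, not part of the book's notebooks.

```python
import numpy as np

def qda_log_posterior(x, priors, means, covs):
    """Normalized log p(y=c|x) for Gaussian class conditionals, following Eq. (9.4)."""
    scores = []
    for pi_c, mu_c, Sigma_c in zip(priors, means, covs):
        diff = x - mu_c
        # log |2 pi Sigma_c|, computed with slogdet for numerical stability
        _, logdet = np.linalg.slogdet(2 * np.pi * Sigma_c)
        # squared Mahalanobis distance (x - mu_c)^T Sigma_c^{-1} (x - mu_c)
        maha = diff @ np.linalg.solve(Sigma_c, diff)
        scores.append(np.log(pi_c) - 0.5 * logdet - 0.5 * maha)
    scores = np.array(scores)
    # subtract the log-sum-exp so that the posterior sums to one over classes
    m = scores.max()
    return scores - (m + np.log(np.exp(scores - m).sum()))

# Hypothetical two-class example with class-specific covariances.
priors = np.array([0.5, 0.5])
means = [np.array([0.0, 0.0]), np.array([2.0, 0.0])]
covs = [np.eye(2), np.array([[2.0, 0.5], [0.5, 1.0]])]
log_post = qda_log_posterior(np.array([1.0, 0.0]), priors, means, covs)
```

Because each class has its own $\boldsymbol{\Sigma}_c$, the quadratic term in $\mathbf{x}$ does not cancel when two class scores are compared, which is why the boundary is quadratic.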

For example, consider the 2d data from 3 different classes in Figure 9.1a. We fit full covariance
Gaussian class-conditionals (using the method explained in Section 9.2.4), and plot the results in
Figure 9.1b. We see that the features for the blue class are somewhat correlated, whereas the features
for the green class are independent, and the features for the red class are independent and isotropic
(spherical covariance). In Figure 9.2a, we see that the resulting decision boundaries are quadratic
functions of $\mathbf{x}$.

Figure 9.3: Geometry of LDA in the 2 class case where $\boldsymbol{\Sigma}_1 = \boldsymbol{\Sigma}_2 = \mathbf{I}$.


9.2.2 Linear decision boundaries 


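Now we consider a special case of Gaussian discriminant analysis in which the covariance matrices are
tied or shared across classes, so $\boldsymbol{\Sigma}_c = \boldsymbol{\Sigma}$. If $\boldsymbol{\Sigma}$ is independent of $c$, we can simplify Equation (9.4)
as follows:

\begin{align}
\log p(y = c|\mathbf{x}, \boldsymbol{\theta}) &= \log \pi_c - \frac{1}{2} (\mathbf{x} - \boldsymbol{\mu}_c)^\top \boldsymbol{\Sigma}^{-1} (\mathbf{x} - \boldsymbol{\mu}_c) + \text{const} \tag{9.5} \\
&= \underbrace{\log \pi_c - \frac{1}{2} \boldsymbol{\mu}_c^\top \boldsymbol{\Sigma}^{-1} \boldsymbol{\mu}_c}_{\gamma_c} + \mathbf{x}^\top \underbrace{\boldsymbol{\Sigma}^{-1} \boldsymbol{\mu}_c}_{\boldsymbol{\beta}_c} + \underbrace{\text{const} - \frac{1}{2} \mathbf{x}^\top \boldsymbol{\Sigma}^{-1} \mathbf{x}}_{\kappa} \tag{9.6} \\
&= \gamma_c + \mathbf{x}^\top \boldsymbol{\beta}_c + \kappa \tag{9.7}
\end{align}

The final term $\kappa$ is independent of $c$, and hence is an irrelevant additive constant that can be dropped.
Hence we see that the discriminant function is a linear function of $\mathbf{x}$, so the decision boundaries will
be linear. Hence this method is called linear discriminant analysis or LDA. See Figure 9.2b for
an example.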
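
As a small illustration of Equations (9.6) and (9.7), the sketch below computes $\gamma_c$ and $\boldsymbol{\beta}_c$ from given class priors, class means, and a shared covariance, and then classifies points with the resulting linear scores. The helper names `lda_params` and `lda_predict` and the toy numbers are assumptions made for this example.

```python
import numpy as np

def lda_params(priors, means, Sigma):
    """Return (gammas, betas) from Eq. (9.6) for a shared covariance Sigma."""
    means = np.asarray(means, dtype=float)             # shape (C, D)
    # beta_c = Sigma^{-1} mu_c, solved for all classes at once
    betas = np.linalg.solve(Sigma, means.T).T           # shape (C, D)
    # gamma_c = log pi_c - 0.5 * mu_c^T Sigma^{-1} mu_c
    gammas = np.log(priors) - 0.5 * np.sum(means * betas, axis=1)
    return gammas, betas

def lda_predict(X, gammas, betas):
    """Pick argmax_c of gamma_c + x^T beta_c; kappa is shared, so it is dropped."""
    return np.argmax(X @ betas.T + gammas, axis=1)

# Hypothetical two-class problem with identity covariance.
priors = np.array([0.5, 0.5])
means = np.array([[0.0, 0.0], [2.0, 0.0]])
gammas, betas = lda_params(priors, means, np.eye(2))
preds = lda_predict(np.array([[0.5, 0.0], [1.5, 0.0]]), gammas, betas)
```

With equal priors and an identity covariance, the boundary is the perpendicular bisector of the segment joining the two means, so the two query points fall on opposite sides.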


9.2.3 The connection between LDA and logistic regression 


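In this section, we derive an interesting connection between LDA and logistic regression, which we
introduced in Section 2.5.3. From Equation (9.7) we can write

$$p(y = c|\mathbf{x}, \boldsymbol{\theta}) = \frac{e^{\boldsymbol{\beta}_c^\top \mathbf{x} + \gamma_c}}{\sum_{c'} e^{\boldsymbol{\beta}_{c'}^\top \mathbf{x} + \gamma_{c'}}} = \frac{e^{\mathbf{w}_c^\top [1, \mathbf{x}]}}{\sum_{c'} e^{\mathbf{w}_{c'}^\top [1, \mathbf{x}]}} \tag{9.8}$$

where $\mathbf{w}_c = [\gamma_c, \boldsymbol{\beta}_c]$. We see that Equation (9.8) has the same form as the multinomial logistic
regression model. The key difference is that in LDA, we first fit the Gaussians (and class prior) to
maximize the joint likelihood $p(\mathbf{x}, y|\boldsymbol{\theta})$, as discussed in Section 9.2.4, and then we derive $\mathbf{w}$ from $\boldsymbol{\theta}$.
By contrast, in logistic regression, we estimate $\mathbf{w}$ directly to maximize the conditional likelihood
$p(y|\mathbf{x}, \mathbf{w})$. In general, these can give different results (see Exercise 10.3).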

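To gain further insight into Equation (9.8), let us consider the binary case. In this case, the
posterior is given by

\begin{align}
p(y = 1|\mathbf{x}, \boldsymbol{\theta}) &= \frac{e^{\boldsymbol{\beta}_1^\top \mathbf{x} + \gamma_1}}{e^{\boldsymbol{\beta}_1^\top \mathbf{x} + \gamma_1} + e^{\boldsymbol{\beta}_0^\top \mathbf{x} + \gamma_0}} = \frac{1}{1 + e^{(\boldsymbol{\beta}_0 - \boldsymbol{\beta}_1)^\top \mathbf{x} + (\gamma_0 - \gamma_1)}} \tag{9.9} \\
&= \sigma\left((\boldsymbol{\beta}_1 - \boldsymbol{\beta}_0)^\top \mathbf{x} + (\gamma_1 - \gamma_0)\right) \tag{9.10}
\end{align}

where $\sigma(\eta)$ refers to the sigmoid function.
Now

\begin{align}
\gamma_1 - \gamma_0 &= -\frac{1}{2} \boldsymbol{\mu}_1^\top \boldsymbol{\Sigma}^{-1} \boldsymbol{\mu}_1 + \frac{1}{2} \boldsymbol{\mu}_0^\top \boldsymbol{\Sigma}^{-1} \boldsymbol{\mu}_0 + \log(\pi_1/\pi_0) \tag{9.11} \\
&= -\frac{1}{2} (\boldsymbol{\mu}_1 - \boldsymbol{\mu}_0)^\top \boldsymbol{\Sigma}^{-1} (\boldsymbol{\mu}_1 + \boldsymbol{\mu}_0) + \log(\pi_1/\pi_0) \tag{9.12}
\end{align}

So if we define

\begin{align}
\mathbf{w} &= \boldsymbol{\beta}_1 - \boldsymbol{\beta}_0 = \boldsymbol{\Sigma}^{-1} (\boldsymbol{\mu}_1 - \boldsymbol{\mu}_0) \tag{9.13} \\
\mathbf{x}_0 &= \frac{1}{2} (\boldsymbol{\mu}_1 + \boldsymbol{\mu}_0) - (\boldsymbol{\mu}_1 - \boldsymbol{\mu}_0) \frac{\log(\pi_1/\pi_0)}{(\boldsymbol{\mu}_1 - \boldsymbol{\mu}_0)^\top \boldsymbol{\Sigma}^{-1} (\boldsymbol{\mu}_1 - \boldsymbol{\mu}_0)} \tag{9.14}
\end{align}

then we have $\mathbf{w}^\top \mathbf{x}_0 = -(\gamma_1 - \gamma_0)$, and hence

$$p(y = 1|\mathbf{x}, \boldsymbol{\theta}) = \sigma(\mathbf{w}^\top (\mathbf{x} - \mathbf{x}_0)) \tag{9.15}$$

This has the same form as binary logistic regression. Hence the MAP decision rule is

$$\hat{y}(\mathbf{x}) = 1 \text{ iff } \mathbf{w}^\top \mathbf{x} > c \tag{9.16}$$

where $c = \mathbf{w}^\top \mathbf{x}_0$. If $\pi_0 = \pi_1 = 0.5$, then the threshold simplifies to $c = \frac{1}{2} \mathbf{w}^\top (\boldsymbol{\mu}_1 + \boldsymbol{\mu}_0)$.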


To interpret this equation geometrically, suppose $\boldsymbol{\Sigma} = \sigma^2 \mathbf{I}$. In this case, $\mathbf{w} = \sigma^{-2} (\boldsymbol{\mu}_1 - \boldsymbol{\mu}_0)$,
which is parallel to a line joining the two centroids, $\boldsymbol{\mu}_0$ and $\boldsymbol{\mu}_1$. So we can classify a point by
projecting it onto this line, and then checking if the projection is closer to $\boldsymbol{\mu}_0$ or $\boldsymbol{\mu}_1$, as illustrated in
Figure 9.3. The question of how close it has to be depends on the prior over classes. If $\pi_1 = \pi_0$, then
$\mathbf{x}_0 = \frac{1}{2} (\boldsymbol{\mu}_1 + \boldsymbol{\mu}_0)$, which is halfway between the means. If we make $\pi_1 > \pi_0$, we have to be closer to
$\boldsymbol{\mu}_0$ than halfway in order to pick class 0. And vice versa if $\pi_0 > \pi_1$. Thus we see that the class prior
just changes the decision threshold, but not the overall shape of the decision boundary. (A similar
argument applies in the multi-class case.)
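
The reduction of binary LDA to a sigmoid can be checked numerically. The following sketch computes $\mathbf{w}$ and $\mathbf{x}_0$ from Equations (9.13) and (9.14) and compares $\sigma(\mathbf{w}^\top(\mathbf{x} - \mathbf{x}_0))$ against the posterior obtained directly from Bayes' rule. All parameter values and the helper names (`binary_lda`, `gauss_unnorm`) are arbitrary choices for illustration.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def binary_lda(pi0, pi1, mu0, mu1, Sigma):
    """Return (w, x0) as defined in Eqs. (9.13) and (9.14)."""
    diff = mu1 - mu0
    w = np.linalg.solve(Sigma, diff)                  # Sigma^{-1} (mu1 - mu0)
    # diff @ w equals (mu1 - mu0)^T Sigma^{-1} (mu1 - mu0)
    x0 = 0.5 * (mu1 + mu0) - diff * np.log(pi1 / pi0) / (diff @ w)
    return w, x0

# Hypothetical parameters with unequal priors and a shared covariance.
mu0, mu1 = np.array([0.0, 0.0]), np.array([2.0, 1.0])
Sigma = np.array([[1.0, 0.3], [0.3, 2.0]])
pi0, pi1 = 0.3, 0.7
w, x0 = binary_lda(pi0, pi1, mu0, mu1, Sigma)

x = np.array([0.5, -1.0])
p_sigmoid = sigmoid(w @ (x - x0))                     # Eq. (9.15)

def gauss_unnorm(x, mu):
    # Gaussian density up to a constant that is shared by both classes
    d = x - mu
    return np.exp(-0.5 * d @ np.linalg.solve(Sigma, d))

num = pi1 * gauss_unnorm(x, mu1)
p_bayes = num / (num + pi0 * gauss_unnorm(x, mu0))    # direct Bayes rule
```

The two probabilities agree because the Gaussian normalization constants cancel when the covariance is shared.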


9.2.4 Model fitting 


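We now discuss how to fit a GDA model using maximum likelihood estimation. The likelihood
function is as follows

$$p(\mathcal{D}|\boldsymbol{\theta}) = \prod_{n=1}^{N} \text{Cat}(y_n|\boldsymbol{\pi}) \prod_{c=1}^{C} \mathcal{N}(\mathbf{x}_n|\boldsymbol{\mu}_c, \boldsymbol{\Sigma}_c)^{\mathbb{I}(y_n = c)} \tag{9.17}$$

Hence the log-likelihood is given by

$$\log p(\mathcal{D}|\boldsymbol{\theta}) = \left[ \sum_{n=1}^{N} \sum_{c=1}^{C} \mathbb{I}(y_n = c) \log \pi_c \right] + \sum_{c=1}^{C} \left[ \sum_{n: y_n = c} \log \mathcal{N}(\mathbf{x}_n|\boldsymbol{\mu}_c, \boldsymbol{\Sigma}_c) \right] \tag{9.18}$$

Thus we see that we can optimize $\boldsymbol{\pi}$ and the $(\boldsymbol{\mu}_c, \boldsymbol{\Sigma}_c)$ terms separately.
From Section 4.2.4, we have that the MLE for the class prior is $\hat{\pi}_c = \frac{N_c}{N}$. Using the results from
Section 4.2.6, we can derive the MLEs for the Gaussians as follows:

\begin{align}
\hat{\boldsymbol{\mu}}_c &= \frac{1}{N_c} \sum_{n: y_n = c} \mathbf{x}_n \tag{9.19} \\
\hat{\boldsymbol{\Sigma}}_c &= \frac{1}{N_c} \sum_{n: y_n = c} (\mathbf{x}_n - \hat{\boldsymbol{\mu}}_c)(\mathbf{x}_n - \hat{\boldsymbol{\mu}}_c)^\top \tag{9.20}
\end{align}

Unfortunately the MLE for $\hat{\boldsymbol{\Sigma}}_c$ can easily overfit (i.e., the estimate may not be well-conditioned) if
$N_c$ is small compared to $D$, the dimensionality of the input features. We discuss some solutions to
this below.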


9.2.4.1 Tied covariances 


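If we force $\boldsymbol{\Sigma}_c = \boldsymbol{\Sigma}$ to be tied, we will get linear decision boundaries, as we have seen. This also
usually results in a more reliable parameter estimate, since we can pool all the samples across classes:

$$\hat{\boldsymbol{\Sigma}} = \frac{1}{N} \sum_{c=1}^{C} \sum_{n: y_n = c} (\mathbf{x}_n - \hat{\boldsymbol{\mu}}_c)(\mathbf{x}_n - \hat{\boldsymbol{\mu}}_c)^\top \tag{9.21}$$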
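
The estimators in Equations (9.19) to (9.21) amount to per-class averages. The sketch below is one possible NumPy implementation; the function names `fit_gda_mle` and `tied_covariance` and the tiny dataset are assumptions made for illustration, and labels are assumed to be integers in `0..C-1`.

```python
import numpy as np

def fit_gda_mle(X, y, num_classes):
    """MLE of class priors, means and full covariances for GDA."""
    N, D = X.shape
    priors = np.zeros(num_classes)
    means = np.zeros((num_classes, D))
    covs = np.zeros((num_classes, D, D))
    for c in range(num_classes):
        Xc = X[y == c]
        priors[c] = len(Xc) / N                       # pi_c = N_c / N
        means[c] = Xc.mean(axis=0)                    # Eq. (9.19)
        centered = Xc - means[c]
        covs[c] = centered.T @ centered / len(Xc)     # Eq. (9.20): divide by N_c
    return priors, means, covs

def tied_covariance(priors, covs):
    """Eq. (9.21): since pi_c = N_c / N, pooling is a prior-weighted average."""
    return np.einsum("c,cij->ij", priors, covs)

# Tiny hypothetical dataset: four points in class 0, two in class 1.
X = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0], [5.0, 5.0], [7.0, 5.0]])
y = np.array([0, 0, 0, 0, 1, 1])
priors, means, covs = fit_gda_mle(X, y, num_classes=2)
Sigma_tied = tied_covariance(priors, covs)
```

In this example the class 1 covariance is singular because its two points differ in only one coordinate, whereas the pooled estimate is full rank, which illustrates why tying can give a more reliable estimate.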


9.2.4.2 Diagonal covariances 


If we force $\boldsymbol{\Sigma}_c$ to be diagonal, we reduce the number of parameters from $O(CD^2)$ to $O(CD)$, which
avoids the overfitting problem. However, this loses the ability to capture correlations between the
features. (This is known as the naive Bayes assumption, which we discuss further in Section 9.3.)
Despite this approximation, this approach scales well to high dimensions.

We can further restrict the model capacity by using a shared (tied) diagonal covariance matrix.
This is called “diagonal LDA” [BL04].


9.2.4.3 MAP estimation 


Forcing the covariance matrix to be diagonal is a rather strong assumption. An alternative approach
is to perform MAP estimation of a (shared) full covariance Gaussian, rather than using the MLE.
Based on the results of Section 4.5.2, we find that the MAP estimate is

$$\hat{\boldsymbol{\Sigma}}_{\text{map}} = \lambda\, \text{diag}(\hat{\boldsymbol{\Sigma}}_{\text{mle}}) + (1 - \lambda) \hat{\boldsymbol{\Sigma}}_{\text{mle}} \tag{9.22}$$

where $\lambda$ controls the amount of regularization. This technique is known as regularized discriminant
analysis or RDA [HTF09, p656].
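
The shrinkage estimate in Equation (9.22) is a one-liner in NumPy. The sketch below applies it to a deliberately rank-deficient MLE (two samples in three dimensions) to show how the result becomes positive definite; the function name `shrink_covariance` and the value $\lambda = 0.5$ are arbitrary illustrative choices.

```python
import numpy as np

def shrink_covariance(Sigma_mle, lam):
    """Eq. (9.22): convex combination of diag(Sigma_mle) and Sigma_mle."""
    return lam * np.diag(np.diag(Sigma_mle)) + (1.0 - lam) * Sigma_mle

# Two samples in three dimensions, so the MLE covariance has rank one.
X = np.array([[1.0, 2.0, 3.0], [3.0, 1.0, 5.0]])
centered = X - X.mean(axis=0)
Sigma_mle = centered.T @ centered / len(X)

Sigma_map = shrink_covariance(Sigma_mle, lam=0.5)
eigs_map = np.linalg.eigvalsh(Sigma_map)   # all strictly positive here
```

Setting `lam=0` recovers the MLE and `lam=1` gives a purely diagonal covariance, so $\lambda$ interpolates between the full-covariance and diagonal models.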
9.2.5 Nearest centroid classifier

If we assume a uniform prior over classes, we can compute the most probable class label as follows:

$$\hat{y}(\mathbf{x}) = \arg\max_c \log p(y = c|\mathbf{x}, \boldsymbol{\theta}) = \arg\min_c (\mathbf{x} - \boldsymbol{\mu}_c)^\top \boldsymbol{\Sigma}^{-1} (\mathbf{x} - \boldsymbol{\mu}_c) \tag{9.23}$$

This is called the nearest centroid classifier, or nearest class mean classifier (NCM), since
we are assigning $\mathbf{x}$ to the class with the closest $\boldsymbol{\mu}_c$, where distance is measured using (squared)
Mahalanobis distance.

We can replace this with any other distance metric to get the decision rule

$$\hat{y}(\mathbf{x}) = \arg\min_c d^2(\mathbf{x}, \boldsymbol{\mu}_c) \tag{9.24}$$

We discuss how to learn distance metrics in Section 16.2, but one simple approach is to use

$$d^2(\mathbf{x}, \boldsymbol{\mu}_c) = ||\mathbf{x} - \boldsymbol{\mu}_c||_{\mathbf{W}}^2 = (\mathbf{x} - \boldsymbol{\mu}_c)^\top (\mathbf{W}^\top \mathbf{W}) (\mathbf{x} - \boldsymbol{\mu}_c) = ||\mathbf{W}(\mathbf{x} - \boldsymbol{\mu}_c)||^2 \tag{9.25}$$

The corresponding class posterior becomes

$$p(y = c|\mathbf{x}, \boldsymbol{\mu}, \mathbf{W}) = \frac{\exp\left(-\frac{1}{2} ||\mathbf{W}(\mathbf{x} - \boldsymbol{\mu}_c)||_2^2\right)}{\sum_{c'=1}^{C} \exp\left(-\frac{1}{2} ||\mathbf{W}(\mathbf{x} - \boldsymbol{\mu}_{c'})||_2^2\right)} \tag{9.26}$$

We can optimize $\mathbf{W}$ using gradient descent applied to the discriminative loss. This is called nearest
class mean metric learning [Men+12]. The advantage of this technique is that it can be used for
one-shot learning of new classes, since we just need to see a single labeled prototype $\boldsymbol{\mu}_c$ per class
(assuming we have learned a good $\mathbf{W}$ already).
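
A nearest class mean classifier with the projected distance of Equation (9.25) can be sketched as follows. Here `W` is simply fixed to the identity rather than learned, and the function names and data are hypothetical; learning `W` by gradient descent is not shown.

```python
import numpy as np

def nearest_centroid_predict(X, centroids, W=None):
    """Assign each row of X to the centroid minimizing ||W (x - mu_c)||^2."""
    if W is None:
        W = np.eye(X.shape[1])                        # plain Euclidean distance
    diffs = X[:, None, :] - centroids[None, :, :]     # shape (N, C, D)
    proj = diffs @ W.T                                # shape (N, C, K)
    sq_dists = np.sum(proj ** 2, axis=-1)             # shape (N, C)
    return np.argmin(sq_dists, axis=1), sq_dists

def ncm_posterior(sq_dists):
    """Softmax over -0.5 * squared distances, as in Eq. (9.26)."""
    logits = -0.5 * sq_dists
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

centroids = np.array([[0.0, 0.0], [4.0, 0.0]])
X = np.array([[1.0, 0.0], [3.0, 0.5]])
preds, sq_dists = nearest_centroid_predict(X, centroids)
post = ncm_posterior(sq_dists)

# One-shot extension: a new class is added by appending a single prototype.
centroids_new = np.vstack([centroids, [1.0, 5.0]])
preds_new, _ = nearest_centroid_predict(np.array([[1.0, 4.0]]), centroids_new)
```

Adding a class only requires appending a row to `centroids`, which is what makes the one-shot setting straightforward once a suitable `W` is available.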


9.2.6 Fisher’s linear discriminant analysis * 


Discriminant analysis is a generative approach to classification, which requires fitting an MVN to
the features. As we have discussed, this can be problematic in high dimensions. An alternative
approach is to reduce the dimensionality of the features $\mathbf{x} \in \mathbb{R}^D$ and then fit an MVN to the
resulting low-dimensional features $\mathbf{z} \in \mathbb{R}^K$. The simplest approach is to use a linear projection
matrix, $\mathbf{z} = \mathbf{W}\mathbf{x}$, where $\mathbf{W}$ is a $K \times D$ matrix. One approach to finding $\mathbf{W}$ would be to use principal
components analysis or PCA (Section 20.1). However, PCA is an unsupervised technique that does
not take class labels into account. Thus the resulting low dimensional features are not necessarily
optimal for classification, as illustrated in Figure 9.4.

An alternative approach is to use gradient based methods to optimize the log likelihood, derived 
from the class posterior in the low dimensional space, as we discussed in Section 9.2.5. 

A third approach (which relies on an eigendecomposition, rather than a gradient-based optimizer) 
is to find the matrix W such that the low-dimensional data can be classified as well as possible using 
a Gaussian class-conditional density model. The assumption of Gaussianity is reasonable since we 
are computing linear combinations of (potentially non-Gaussian) features. This approach is called 
Fisher’s linear discriminant analysis, or FLDA. 

FLDA is an interesting hybrid of discriminative and generative techniques. The drawback of this
technique is that it is restricted to using $K \le C - 1$ dimensions, regardless of $D$, for reasons that we
will explain below. In the two-class case, this means we are seeking a single vector $\mathbf{w}$ onto which we
can project the data. Below we derive the optimal $\mathbf{w}$ in the two-class case. We then generalize to
the multi-class case, and finally we give a probabilistic interpretation of this technique.

Figure 9.4: Linear discriminant analysis applied to two class dataset in 2d, representing (standardized) height
and weight for male and female adults. (a) PCA direction. (b) FLDA direction. (c) Projection onto PCA
direction shows poor class separation. (d) Projection onto FLDA direction shows good class separation.
Generated by fisher_lda_demo.ipynb.


9.2.6.1 Derivation of the optimal 1d projection 


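We now derive this optimal direction $\mathbf{w}$, for the two-class case, following the presentation of [Bis06,
Sec 4.1.4]. Define the class-conditional means as

$$\boldsymbol{\mu}_1 = \frac{1}{N_1} \sum_{n: y_n = 1} \mathbf{x}_n, \quad \boldsymbol{\mu}_2 = \frac{1}{N_2} \sum_{n: y_n = 2} \mathbf{x}_n \tag{9.27}$$

Let $m_k = \mathbf{w}^\top \boldsymbol{\mu}_k$ be the projection of each mean onto the line $\mathbf{w}$. Also, let $z_n = \mathbf{w}^\top \mathbf{x}_n$ be the
projection of the data onto the line. The variance of the projected points is proportional to

$$s_k^2 = \sum_{n: y_n = k} (z_n - m_k)^2 \tag{9.28}$$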


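The goal is to find $\mathbf{w}$ such that we maximize the distance between the means, $m_2 - m_1$, while also
ensuring the projected clusters are “tight”, which we can do by minimizing their variance. This
suggests the following objective:

$$J(\mathbf{w}) = \frac{(m_2 - m_1)^2}{s_1^2 + s_2^2} \tag{9.29}$$

We can rewrite the right hand side of the above in terms of $\mathbf{w}$ as follows

$$J(\mathbf{w}) = \frac{\mathbf{w}^\top \mathbf{S}_B \mathbf{w}}{\mathbf{w}^\top \mathbf{S}_W \mathbf{w}} \tag{9.30}$$

where $\mathbf{S}_B$ is the between-class scatter matrix given by

$$\mathbf{S}_B = (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)(\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)^\top \tag{9.31}$$

and $\mathbf{S}_W$ is the within-class scatter matrix, given by

$$\mathbf{S}_W = \sum_{n: y_n = 1} (\mathbf{x}_n - \boldsymbol{\mu}_1)(\mathbf{x}_n - \boldsymbol{\mu}_1)^\top + \sum_{n: y_n = 2} (\mathbf{x}_n - \boldsymbol{\mu}_2)(\mathbf{x}_n - \boldsymbol{\mu}_2)^\top \tag{9.32}$$

To see this, note that

$$\mathbf{w}^\top \mathbf{S}_B \mathbf{w} = \mathbf{w}^\top (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)(\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)^\top \mathbf{w} = (m_2 - m_1)(m_2 - m_1) \tag{9.33}$$

and

\begin{align}
\mathbf{w}^\top \mathbf{S}_W \mathbf{w} &= \sum_{n: y_n = 1} \mathbf{w}^\top (\mathbf{x}_n - \boldsymbol{\mu}_1)(\mathbf{x}_n - \boldsymbol{\mu}_1)^\top \mathbf{w} + \sum_{n: y_n = 2} \mathbf{w}^\top (\mathbf{x}_n - \boldsymbol{\mu}_2)(\mathbf{x}_n - \boldsymbol{\mu}_2)^\top \mathbf{w} \tag{9.34} \\
&= \sum_{n: y_n = 1} (z_n - m_1)^2 + \sum_{n: y_n = 2} (z_n - m_2)^2 \tag{9.35}
\end{align}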

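Equation (9.30) is a ratio of two scalars; we can take its derivative with respect to $\mathbf{w}$ and equate to
zero. One can show (Exercise 9.1) that $J(\mathbf{w})$ is maximized when

$$\mathbf{S}_B \mathbf{w} = \lambda \mathbf{S}_W \mathbf{w} \tag{9.36}$$

where

$$\lambda = \frac{\mathbf{w}^\top \mathbf{S}_B \mathbf{w}}{\mathbf{w}^\top \mathbf{S}_W \mathbf{w}} \tag{9.37}$$

Equation (9.36) is called a generalized eigenvalue problem. If $\mathbf{S}_W$ is invertible, we can convert it
to a regular eigenvalue problem:

$$\mathbf{S}_W^{-1} \mathbf{S}_B \mathbf{w} = \lambda \mathbf{w} \tag{9.38}$$

However, in the two class case, there is a simpler solution. In particular, since

$$\mathbf{S}_B \mathbf{w} = (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)(\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)^\top \mathbf{w} = (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)(m_2 - m_1) \tag{9.39}$$

then, from Equation (9.38) we have

\begin{align}
\lambda \mathbf{w} &= \mathbf{S}_W^{-1} (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)(m_2 - m_1) \tag{9.40} \\
\mathbf{w} &\propto \mathbf{S}_W^{-1} (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1) \tag{9.41}
\end{align}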


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 




Figure 9.5: (a) PCA projection of vowel data to 2d. (b) FLDA projection of vowel data to 2d. We see
there is better class separation in the FLDA case. Adapted from Figure 4.11 of [HTF09]. Generated by
fisher_discrim_vowel.ipynb.


Since we only care about the directionality, and not the scale factor, we can just set

$$\mathbf{w} = \mathbf{S}_W^{-1} (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1) \tag{9.42}$$

This is the optimal solution in the two-class case. If $\mathbf{S}_W \propto \mathbf{I}$, meaning the pooled covariance matrix
is isotropic, then $\mathbf{w}$ is proportional to the vector that joins the class means. This is an intuitively
reasonable direction to project onto, as shown in Figure 9.3.
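
The closed-form direction in Equation (9.42) can be checked against the objective in Equation (9.30). The sketch below uses synthetic Gaussian data with a fixed seed; the function names `fisher_direction` and `fisher_objective`, the sample sizes, and the covariance are illustrative assumptions.

```python
import numpy as np

def fisher_direction(X1, X2):
    """Two-class Fisher direction w = S_W^{-1} (mu2 - mu1), Eq. (9.42)."""
    mu1, mu2 = X1.mean(axis=0), X2.mean(axis=0)
    # within-class scatter matrix, Eq. (9.32)
    S_W = (X1 - mu1).T @ (X1 - mu1) + (X2 - mu2).T @ (X2 - mu2)
    w = np.linalg.solve(S_W, mu2 - mu1)
    return w, mu1, mu2, S_W

def fisher_objective(w, mu1, mu2, S_W):
    """J(w) = (w^T S_B w) / (w^T S_W w), Eq. (9.30), with S_B from Eq. (9.31)."""
    return (w @ (mu2 - mu1)) ** 2 / (w @ S_W @ w)

# Synthetic, correlated two-class data (hypothetical).
rng = np.random.default_rng(0)
cov = np.array([[3.0, 2.0], [2.0, 3.0]])
X1 = rng.multivariate_normal([0.0, 0.0], cov, size=50)
X2 = rng.multivariate_normal([2.0, 0.0], cov, size=50)

w, mu1, mu2, S_W = fisher_direction(X1, X2)
J_fisher = fisher_objective(w, mu1, mu2, S_W)
J_means = fisher_objective(mu2 - mu1, mu1, mu2, S_W)   # naive direction joining the means
```

Since $J$ is invariant to rescaling $\mathbf{w}$, only the direction matters; the Fisher direction should never score lower than the line joining the means, and the two coincide when $\mathbf{S}_W \propto \mathbf{I}$.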


9.2.6.2 Extension to higher dimensions and multiple classes 


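We can extend the above idea to multiple classes, and to higher dimensional subspaces, by finding a
projection matrix $\mathbf{W}$ which maps from $D$ to $K$. Let $\mathbf{z}_n = \mathbf{W}\mathbf{x}_n$ be the low dimensional projection
of the $n$'th data point. Let $\mathbf{m}_c = \frac{1}{N_c} \sum_{n: y_n = c} \mathbf{z}_n$ be the corresponding mean for the $c$'th class and
$\mathbf{m} = \frac{1}{N} \sum_{c=1}^{C} N_c \mathbf{m}_c$ be the overall mean, both in the low dimensional space. We define the following
scatter matrices:

\begin{align}
\tilde{\mathbf{S}}_W &= \sum_{c=1}^{C} \sum_{n: y_n = c} (\mathbf{z}_n - \mathbf{m}_c)(\mathbf{z}_n - \mathbf{m}_c)^\top \tag{9.43} \\
\tilde{\mathbf{S}}_B &= \sum_{c=1}^{C} N_c (\mathbf{m}_c - \mathbf{m})(\mathbf{m}_c - \mathbf{m})^\top \tag{9.44}
\end{align}

Finally, we define the objective function as maximizing the following:²

$$J(\mathbf{W}) = \frac{|\tilde{\mathbf{S}}_B|}{|\tilde{\mathbf{S}}_W|} = \frac{|\mathbf{W}^\top \mathbf{S}_B \mathbf{W}|}{|\mathbf{W}^\top \mathbf{S}_W \mathbf{W}|} \tag{9.45}$$

where $\mathbf{S}_W$ and $\mathbf{S}_B$ are defined in the original high dimensional space in the obvious way (namely using
$\mathbf{x}_n$ instead of $\mathbf{z}_n$, $\boldsymbol{\mu}_c$ instead of $\mathbf{m}_c$, and $\boldsymbol{\mu}$ instead of $\mathbf{m}$). The solution can be shown [DHS01] to be
$\mathbf{W} = \mathbf{S}_W^{-\frac{1}{2}} \mathbf{U}$, where $\mathbf{U}$ are the $K$ leading eigenvectors of $\mathbf{S}_W^{-\frac{1}{2}} \mathbf{S}_B \mathbf{S}_W^{-\frac{1}{2}}$, assuming $\mathbf{S}_W$ is non-singular.
(If it is singular, we can first perform PCA on all the data.)

2. An alternative criterion that is sometimes used [Fuk90] is $J(\mathbf{W}) = \text{tr}\{\tilde{\mathbf{S}}_W^{-1} \tilde{\mathbf{S}}_B\} = \text{tr}\{(\mathbf{W} \mathbf{S}_W \mathbf{W}^\top)^{-1} (\mathbf{W} \mathbf{S}_B \mathbf{W}^\top)\}$.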

Figure 9.5 gives an example of this method applied to some D = 10 dimensional speech data, 
representing C = 11 different vowel sounds. We project to K = 2 dimensions in order to visualize 
the data. We see that FLDA gives better class separation than PCA. 

Note that FLDA is restricted to finding at most a $K \le C - 1$ dimensional linear subspace, no
matter how large $D$, because the rank of the between class scatter matrix $\mathbf{S}_B$ is at most $C - 1$. (The $-1$ term
arises because of the $\boldsymbol{\mu}$ term, which is a linear function of the $\boldsymbol{\mu}_c$.) This is a rather severe restriction
which limits the usefulness of FLDA.
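
The multi-class solution $\mathbf{W} = \mathbf{S}_W^{-1/2} \mathbf{U}$ can be sketched with two symmetric eigendecompositions. In the code below, `flda_fit` is a hypothetical helper, the data are synthetic, and the returned matrix has shape $D \times K$, so data are projected with `X @ W`. The sketch assumes $\mathbf{S}_W$ is non-singular.

```python
import numpy as np

def flda_fit(X, y, K):
    """Return W (D x K), the K leading eigenvalues, and the scatter matrices."""
    D = X.shape[1]
    mu = X.mean(axis=0)
    S_W = np.zeros((D, D))
    S_B = np.zeros((D, D))
    for c in np.unique(y):
        Xc = X[y == c]
        mu_c = Xc.mean(axis=0)
        S_W += (Xc - mu_c).T @ (Xc - mu_c)                 # within-class scatter
        S_B += len(Xc) * np.outer(mu_c - mu, mu_c - mu)    # between-class scatter
    # S_W^{-1/2} from the eigendecomposition of the symmetric matrix S_W
    evals, evecs = np.linalg.eigh(S_W)
    S_W_inv_sqrt = evecs @ np.diag(evals ** -0.5) @ evecs.T
    M = S_W_inv_sqrt @ S_B @ S_W_inv_sqrt
    lam, U = np.linalg.eigh(M)                             # eigenvalues in ascending order
    order = np.argsort(lam)[::-1][:K]                      # indices of the K largest
    W = S_W_inv_sqrt @ U[:, order]
    return W, lam[order], S_W, S_B

# Synthetic data: C = 3 classes in D = 4 dimensions (hypothetical).
rng = np.random.default_rng(0)
class_means = rng.normal(scale=3.0, size=(3, 4))
X = np.vstack([rng.normal(loc=m, size=(30, 4)) for m in class_means])
y = np.repeat(np.arange(3), 30)

W, lam, S_W, S_B = flda_fit(X, y, K=2)
Z = X @ W                                                  # projected data, shape (90, 2)
rank_S_B = np.linalg.matrix_rank(S_B, tol=1e-8)            # C - 1 = 2 for this data
```

With three classes, only two eigenvalues of $\mathbf{S}_W^{-1/2} \mathbf{S}_B \mathbf{S}_W^{-1/2}$ are non-zero, which is the rank restriction discussed above.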


9.3 Naive Bayes classifiers 


In this section, we discuss a simple generative approach to classification in which we assume the features 
are conditionally independent given the class label. This is called the naive Bayes assumption. 
The model is called “naive” since we do not expect the features to be independent, even conditional on 
the class label. However, even if the naive Bayes assumption is not true, it often results in classifiers 
that work well [DP97; HY01a]. One reason for this is that the model is quite simple (it only has
O(CD) parameters, for C classes and D features), and hence it is relatively immune to overfitting. 

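More precisely, the naive Bayes assumption corresponds to using a class conditional density of the
following form:

$$p(\mathbf{x}|y = c, \boldsymbol{\theta}) = \prod_{d=1}^{D} p(x_d|y = c, \boldsymbol{\theta}_{dc}) \tag{9.46}$$

where $\boldsymbol{\theta}_{dc}$ are the parameters for the class conditional density for class $c$ and feature $d$. Hence the
posterior over class labels is given by

$$p(y = c|\mathbf{x}, \boldsymbol{\theta}) = \frac{p(y = c|\boldsymbol{\pi}) \prod_{d=1}^{D} p(x_d|y = c, \boldsymbol{\theta}_{dc})}{\sum_{c'} p(y = c'|\boldsymbol{\pi}) \prod_{d=1}^{D} p(x_d|y = c', \boldsymbol{\theta}_{dc'})} \tag{9.47}$$

where $\pi_c$ is the prior probability of class $c$, and $\boldsymbol{\theta} = (\boldsymbol{\pi}, \{\boldsymbol{\theta}_{dc}\})$ are all the parameters. This is known
as a naive Bayes classifier or NBC.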


9.3.1 Example models 


We still need to specify the form of the probability distributions in Equation (9.46). This depends on
what type of feature $x_d$ is. We give some examples below:


• In the case of binary features, $x_d \in \{0, 1\}$, we can use the Bernoulli distribution: $p(\mathbf{x}|y = c, \boldsymbol{\theta}) =
\prod_{d=1}^{D} \text{Ber}(x_d|\theta_{dc})$, where $\theta_{dc}$ is the probability that $x_d = 1$ in class $c$. This is sometimes called the
multivariate Bernoulli naive Bayes model. For example, Figure 9.6 shows the estimated
parameters for each class when we fit this model to a binarized version of MNIST. This approach
does surprisingly well, and has a test set accuracy of 84.3%. (See Figure 9.7 for some sample
predictions.)


• In the case of categorical features, $x_d \in \{1, \ldots, K\}$, we can use the categorical distribution:
$p(\mathbf{x}|y = c, \boldsymbol{\theta}) = \prod_{d=1}^{D} \text{Cat}(x_d|\boldsymbol{\theta}_{dc})$, where $\theta_{dck}$ is the probability that $x_d = k$ given that $y = c$.


• In the case of real-valued features, $x_d \in \mathbb{R}$, we can use the univariate Gaussian distribution:
$p(\mathbf{x}|y = c, \boldsymbol{\theta}) = \prod_{d=1}^{D} \mathcal{N}(x_d|\mu_{dc}, \sigma_{dc}^2)$, where $\mu_{dc}$ is the mean of feature $d$ when the class label is
$c$, and $\sigma_{dc}^2$ is its variance. (This is equivalent to Gaussian discriminant analysis using diagonal
covariance matrices.)


Figure 9.6: Visualization of the Bernoulli class conditional densities for a naive Bayes classifier fit to a
binarized version of the MNIST dataset. Generated by naive_bayes_mnist_jax.ipynb.


Figure 9.7: Visualization of the predictions made by the model in Figure 9.6 when applied to some
binarized MNIST test images. The title shows the most probable predicted class. Generated by
naive_bayes_mnist_jax.ipynb.


9.3.2 Model fitting 


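In this section, we discuss how to fit a naive Bayes classifier using maximum likelihood estimation.
We can write the likelihood as follows:

\begin{align}
p(\mathcal{D}|\boldsymbol{\theta}) &= \prod_{n=1}^{N} \left[ \text{Cat}(y_n|\boldsymbol{\pi}) \prod_{d=1}^{D} p(x_{nd}|y_n, \boldsymbol{\theta}_d) \right] \tag{9.48} \\
&= \prod_{n=1}^{N} \left[ \text{Cat}(y_n|\boldsymbol{\pi}) \prod_{d=1}^{D} \prod_{c=1}^{C} p(x_{nd}|\boldsymbol{\theta}_{dc})^{\mathbb{I}(y_n = c)} \right] \tag{9.49}
\end{align}

so the log-likelihood is given by

$$\log p(\mathcal{D}|\boldsymbol{\theta}) = \left[ \sum_{n=1}^{N} \sum_{c=1}^{C} \mathbb{I}(y_n = c) \log \pi_c \right] + \sum_{c=1}^{C} \sum_{d=1}^{D} \left[ \sum_{n: y_n = c} \log p(x_{nd}|\boldsymbol{\theta}_{dc}) \right] \tag{9.50}$$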


We see that this decomposes into a term for $\boldsymbol{\pi}$, and $CD$ terms for each $\boldsymbol{\theta}_{dc}$:

$$\log p(\mathcal{D}|\boldsymbol{\theta}) = \log p(\mathcal{D}_y|\boldsymbol{\pi}) + \sum_{c} \sum_{d} \log p(\mathcal{D}_{dc}|\boldsymbol{\theta}_{dc}) \tag{9.51}$$

where $\mathcal{D}_y = \{y_n : n = 1 : N\}$ are all the labels, and $\mathcal{D}_{dc} = \{x_{nd} : y_n = c\}$ are all the values of feature
$d$ for examples from class $c$. Hence we can estimate these parameters separately.

In Section 4.2.4, we show that the MLE for $\boldsymbol{\pi}$ is the vector of empirical counts, $\hat{\pi}_c = \frac{N_c}{N}$. The
MLEs for $\boldsymbol{\theta}_{dc}$ depend on the choice of the class conditional density for feature $d$. We discuss some
common choices below.


• In the case of discrete features, we can use a categorical distribution. A straightforward extension
of the results in Section 4.2.4 gives the following expression for the MLE:

$$\hat{\theta}_{dck} = \frac{N_{dck}}{\sum_{k'=1}^{K} N_{dck'}} = \frac{N_{dck}}{N_c} \tag{9.52}$$

where $N_{dck} = \sum_{n=1}^{N} \mathbb{I}(x_{nd} = k, y_n = c)$ is the number of times that feature $d$ had value $k$ in
examples of class $c$.


• In the case of binary features, the categorical distribution becomes the Bernoulli, and the MLE
becomes

$$\hat{\theta}_{dc} = \frac{N_{dc}}{N_c} \tag{9.53}$$

which is the empirical fraction of times that feature $d$ is on in examples of class $c$.


• In the case of real-valued features, we can use a Gaussian distribution. A straightforward extension
of the results in Section 4.2.5 gives the following expression for the MLE:

\begin{align}
\hat{\mu}_{dc} &= \frac{1}{N_c} \sum_{n: y_n = c} x_{nd} \tag{9.54} \\
\hat{\sigma}_{dc}^2 &= \frac{1}{N_c} \sum_{n: y_n = c} (x_{nd} - \hat{\mu}_{dc})^2 \tag{9.55}
\end{align}


Thus we see that fitting a naive Bayes classifier is extremely simple and efficient. 
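
For binary features, the MLE in Equation (9.53) reduces to per-class column means. The following sketch shows this on a five-example toy dataset; `fit_bernoulli_nb` is a hypothetical helper name, and labels are assumed to be integers in `0..C-1`.

```python
import numpy as np

def fit_bernoulli_nb(X, y, num_classes):
    """MLE for a Bernoulli naive Bayes model with binary feature matrix X."""
    N, D = X.shape
    pi = np.zeros(num_classes)
    theta = np.zeros((num_classes, D))
    for c in range(num_classes):
        Xc = X[y == c]
        pi[c] = len(Xc) / N            # class prior: N_c / N
        theta[c] = Xc.mean(axis=0)     # Eq. (9.53): N_dc / N_c for each feature d
    return pi, theta

X = np.array([[1, 0, 1],
              [1, 1, 0],
              [0, 0, 1],
              [0, 1, 1],
              [0, 0, 0]])
y = np.array([0, 0, 1, 1, 1])
pi, theta = fit_bernoulli_nb(X, y, num_classes=2)
```

Note that `theta[0, 0]` is exactly 1 and `theta[1, 0]` is exactly 0 here, so a single test example can receive zero likelihood under a class. This is the kind of overfitting that the Bayesian treatment in Section 9.3.3 addresses.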


9.3.3 Bayesian naive Bayes 


In this section, we extend our discussion of MLE for naive Bayes classifiers from Section 9.3.2
to compute the posterior distribution over the parameters. For simplicity, let us assume we have
categorical features, so $p(x_d|\boldsymbol{\theta}_{dc}) = \text{Cat}(x_d|\boldsymbol{\theta}_{dc})$, where $\theta_{dck} = p(x_d = k|y = c)$. In Section 4.6.3.2,
we show that the conjugate prior for the categorical likelihood is the Dirichlet distribution, $p(\boldsymbol{\theta}_{dc}) =
\text{Dir}(\boldsymbol{\theta}_{dc}|\boldsymbol{\beta}_{dc})$, where $\beta_{dck}$ can be interpreted as a set of “pseudo counts”, corresponding to counts
$N_{dck}$ that come from prior data. Similarly we use a Dirichlet prior for the label frequencies,
$p(\boldsymbol{\pi}) = \text{Dir}(\boldsymbol{\pi}|\boldsymbol{\alpha})$. By using a conjugate prior, we can compute the posterior in closed form, as we
explain in Section 4.6.3. In particular, we have

$$p(\boldsymbol{\theta}|\mathcal{D}) = \text{Dir}(\boldsymbol{\pi}|\widehat{\boldsymbol{\alpha}}) \prod_{d=1}^{D} \prod_{c=1}^{C} \text{Dir}(\boldsymbol{\theta}_{dc}|\widehat{\boldsymbol{\beta}}_{dc}) \tag{9.56}$$

where $\widehat{\alpha}_c = \alpha_c + N_c$ and $\widehat{\beta}_{dck} = \beta_{dck} + N_{dck}$ are the posterior hyperparameters.

Using the results from Section 4.6.3.4, we can derive the posterior predictive distribution as
follows. For the label prior (before seeing $\mathbf{x}$, but after seeing $\mathcal{D}$), we have $p(y|\mathcal{D}) = \text{Cat}(y|\bar{\boldsymbol{\pi}})$, where
$\bar{\pi}_c = \widehat{\alpha}_c / \sum_{c'} \widehat{\alpha}_{c'}$. For the feature likelihood of $\mathbf{x}$ (given $y$ and $\mathcal{D}$), we have $p(x_d = k|y = c, \mathcal{D}) = \bar{\theta}_{dck}$,
where

$$\bar{\theta}_{dck} = \frac{\widehat{\beta}_{dck}}{\sum_{k'=1}^{K} \widehat{\beta}_{dck'}} = \frac{\beta_{dck} + N_{dck}}{\sum_{k'=1}^{K} (\beta_{dck'} + N_{dck'})} \tag{9.57}$$

is the posterior mean of the parameters. (Note that $\sum_{k'=1}^{K} N_{dck'} = N_{dc} = N_c$ is the number of
examples for class $c$.)

If $\beta_{dck} = 0$, this reduces to the MLE in Equation (9.52). By contrast, if we set $\beta_{dck} = 1$, we add
1 to all the empirical counts before normalizing. This is called add-one smoothing or Laplace
smoothing. For example, in the binary case, this gives

$$\bar{\theta}_{dc} = \frac{\beta_{dc1} + N_{dc1}}{\beta_{dc0} + N_{dc0} + \beta_{dc1} + N_{dc1}} = \frac{1 + N_{dc1}}{2 + N_{dc}} \tag{9.58}$$
We can finally compute the posterior predictive distribution over the label as follows:

$$p(y = c|\mathbf{x}, \mathcal{D}) \propto p(y = c|\mathcal{D}) \prod_{d} p(x_d|y = c, \mathcal{D}) = \bar{\pi}_c \prod_{d} \prod_{k} \bar{\theta}_{dck}^{\mathbb{I}(x_d = k)} \tag{9.59}$$

This gives us a fully Bayesian form of naive Bayes, in which we have integrated out all the parameters.
(In this case, the predictive distribution can be obtained merely by plugging in the posterior mean
parameters.)
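
For binary features, the posterior means in Equations (9.57) and (9.58) and the predictive rule in Equation (9.59) can be sketched as below. The function names are hypothetical, and the sketch assumes symmetric Dirichlet priors with a single scalar pseudo-count `alpha` for the labels and `beta` for each feature value.

```python
import numpy as np

def fit_bernoulli_nb_smoothed(X, y, num_classes, alpha=1.0, beta=1.0):
    """Posterior-mean parameters under symmetric Dirichlet (Beta) priors."""
    N, D = X.shape
    Nc = np.array([(y == c).sum() for c in range(num_classes)])              # N_c
    Ndc1 = np.array([X[y == c].sum(axis=0) for c in range(num_classes)])     # N_dc1, shape (C, D)
    pi_bar = (alpha + Nc) / (num_classes * alpha + N)
    theta_bar = (beta + Ndc1) / (2 * beta + Nc[:, None])   # Eq. (9.58) when beta = 1
    return pi_bar, theta_bar

def predict_log_posterior(X, pi_bar, theta_bar):
    """Log of Eq. (9.59), normalized over classes, for binary features."""
    log_lik = X @ np.log(theta_bar).T + (1 - X) @ np.log1p(-theta_bar).T     # shape (N, C)
    logits = np.log(pi_bar) + log_lik
    logits = logits - logits.max(axis=1, keepdims=True)
    return logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

X = np.array([[1, 0, 1],
              [1, 1, 0],
              [0, 0, 1],
              [0, 1, 1],
              [0, 0, 0]])
y = np.array([0, 0, 1, 1, 1])
pi_bar, theta_bar = fit_bernoulli_nb_smoothed(X, y, num_classes=2)
log_post = predict_log_posterior(X, pi_bar, theta_bar)
```

With add-one smoothing, every entry of `theta_bar` lies strictly between 0 and 1, so the log-likelihoods stay finite even for feature values that were never observed in a class.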


9.3.4 The connection between naive Bayes and logistic regression 


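In this section, we show that the class posterior $p(y|\mathbf{x}, \boldsymbol{\theta})$ for a NBC model has the same form as
multinomial logistic regression. For simplicity, we assume that the features are all discrete, and each
has $K$ states, although the result holds for arbitrary feature distributions in the exponential family.

Let $x_{dk} = \mathbb{I}(x_d = k)$, so $\mathbf{x}_d$ is a one-hot encoding of feature $d$. Then the class conditional density
can be written as follows:

$$p(\mathbf{x}|y = c, \boldsymbol{\theta}) = \prod_{d=1}^{D} \text{Cat}(\mathbf{x}_d|y = c, \boldsymbol{\theta}) = \prod_{d=1}^{D} \prod_{k=1}^{K} \theta_{dck}^{x_{dk}} \tag{9.60}$$

Hence the posterior over classes is given by

$$p(y = c|\mathbf{x}, \boldsymbol{\theta}) = \frac{\pi_c \prod_{d} \prod_{k} \theta_{dck}^{x_{dk}}}{\sum_{c'} \pi_{c'} \prod_{d} \prod_{k} \theta_{dc'k}^{x_{dk}}} = \frac{\exp\left[\log \pi_c + \sum_{d} \sum_{k} x_{dk} \log \theta_{dck}\right]}{\sum_{c'} \exp\left[\log \pi_{c'} + \sum_{d} \sum_{k} x_{dk} \log \theta_{dc'k}\right]} \tag{9.61}$$

This can be written as a softmax

$$p(y = c|\mathbf{x}, \boldsymbol{\theta}) = \frac{e^{\boldsymbol{\beta}_c^\top \mathbf{x} + \gamma_c}}{\sum_{c'=1}^{C} e^{\boldsymbol{\beta}_{c'}^\top \mathbf{x} + \gamma_{c'}}} \tag{9.62}$$

by suitably defining $\boldsymbol{\beta}_c$ and $\gamma_c$. This has exactly the same form as multinomial logistic regression in
Section 2.5.3. The difference is that with naive Bayes we optimize the joint likelihood $\prod_n p(y_n, \mathbf{x}_n|\boldsymbol{\theta})$,
whereas with logistic regression, we optimize the conditional likelihood $\prod_n p(y_n|\mathbf{x}_n, \boldsymbol{\theta})$. In general,
these can give different results (see Exercise 10.3).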
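
One concrete choice that realizes Equation (9.62) is $\gamma_c = \log \pi_c$ with $\boldsymbol{\beta}_c$ the vector that stacks $\log \theta_{dck}$ over $(d, k)$, applied to the concatenated one-hot features. The sketch below checks this numerically on randomly drawn parameters; the sizes and variable names are arbitrary illustrative choices, and features take values in `0..K-1`.

```python
import numpy as np

rng = np.random.default_rng(0)
C, D, K = 3, 4, 5
pi = rng.dirichlet(np.ones(C))                      # class prior
theta = rng.dirichlet(np.ones(K), size=(D, C))      # theta[d, c, :] sums to one

x = rng.integers(0, K, size=D)                      # one categorical feature vector
x_onehot = np.eye(K)[x].ravel()                     # concatenated one-hot encoding, length D*K

# Direct evaluation of the class posterior, as in Eq. (9.61).
log_joint = np.log(pi) + np.array(
    [sum(np.log(theta[d, c, x[d]]) for d in range(D)) for c in range(C)]
)
post_direct = np.exp(log_joint) / np.exp(log_joint).sum()

# Softmax form, Eq. (9.62): beta_c stacks log theta_{dck}, gamma_c = log pi_c.
beta = np.log(theta).transpose(1, 0, 2).reshape(C, D * K)
gamma = np.log(pi)
logits = beta @ x_onehot + gamma
post_softmax = np.exp(logits - logits.max())
post_softmax = post_softmax / post_softmax.sum()
```

The two posteriors agree, but the weights `beta` here are constrained by the generative model, whereas logistic regression would fit them freely.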
Figure 9.8: The class-conditional densities $p(x|y = c)$ (left) may be more complex than the class posteriors
$p(y = c|x)$ (right). Adapted from Figure 1.27 of [Bis06]. Generated by generativeVsDiscrim.ipynb.


9.4 Generative vs discriminative classifiers 


A model of the form p(x, y) = p(y)p(x|y) is called a generative classifier, since it can be used 
to generate examples x from each class y. By contrast, a model of the form p(y|x) is called a 
discriminative classifier, since it can only be used to discriminate between different classes. Below 
we discuss various pros and cons of the generative and discriminative approaches to classification. 
(See also [BT04; UB05; LBM06; BL07a; Rot+18].)


9.4.1 Advantages of discriminative classifiers 
The main advantages of discriminative classifiers are as follows: 


• Better predictive accuracy. Discriminative classifiers are often much more accurate than
generative classifiers [NJ02]. The reason is that the conditional distribution $p(y|\mathbf{x})$ is often much
simpler (and therefore easier to learn) than the joint distribution $p(y, \mathbf{x})$, as illustrated in Figure 9.8.
In particular, discriminative models do not need to “waste effort” modeling the distribution of the
input features.


• Can handle feature preprocessing. A big advantage of discriminative methods is that they
allow us to preprocess the input in arbitrary ways. For example, we can perform a polynomial
expansion of the input features, and we can replace a string of words with embedding vectors (see
Section 20.5). It is often hard to define a generative model on such pre-processed data, since the
new features can be correlated in complex ways which are hard to model.


• Well-calibrated probabilities. Some generative classifiers, such as naive Bayes (described in
Section 9.3), make strong independence assumptions which are often not valid. This can result
in very extreme posterior class probabilities (very near 0 or 1). Discriminative models, such as
logistic regression, are often better calibrated in terms of their probability estimates, although
they also sometimes need adjustment (see e.g., [NMC05]).


9.4.2 Advantages of generative classifiers


The main advantages of generative classifiers are as follows: 


• Easy to fit. Generative classifiers are often very easy to fit. For example, in Section 9.3.2, we
show how to fit a naive Bayes classifier by simple counting and averaging. By contrast, logistic
regression requires solving a convex optimization problem (see Section 10.2.3 for the details), and
neural nets require solving a non-convex optimization problem, both of which are much slower.


• Can easily handle missing input features. Sometimes some of the inputs (components of $\mathbf{x}$)
are not observed. In a generative classifier, there is a simple method for dealing with this, as we
show in Section 1.5.5. However, in a discriminative classifier, there is no principled solution to
this problem, since the model assumes that $\mathbf{x}$ is always available to be conditioned on.


• Can fit classes separately. In a generative classifier, we estimate the parameters of each class
conditional density independently (as we show in Section 9.3.2), so we do not have to retrain
the model when we add more classes. In contrast, in discriminative models, all the parameters
interact, so the whole model must be retrained if we add a new class.


• Can handle unlabeled training data. It is easy to use generative models for semi-supervised
learning, in which we combine labeled data $\mathcal{D}_{xy} = \{(\mathbf{x}_n, y_n)\}$ and unlabeled data, $\mathcal{D}_x = \{\mathbf{x}_n\}$.
However, this is harder to do with discriminative models, since there is no uniquely optimal way
to exploit $\mathcal{D}_x$.


• May be more robust to spurious features. A discriminative model $p(y|\mathbf{x})$ may pick up on
features of the input $\mathbf{x}$ that can discriminate different values of $y$ in the training set, but which are
not robust and do not generalize beyond the training set. These are called spurious features
(see e.g., [Arj21; Zho+21]). By contrast, a generative model $p(\mathbf{x}|y)$ may be better able to capture
the causal mechanisms of the underlying data generating process; such causal models can be more
robust to distribution shift (see e.g., [Sch19; LBS19; LN81]).


9.4.3 Handling missing features 


Sometimes we are missing parts of the input $\mathbf{x}$ during training and/or testing. In a generative
classifier, we can handle this situation by marginalizing out the missing values. (We assume that
the missingness of a feature is not informative about its potential value.) By contrast, when using
a discriminative model, there is no unique best way to handle missing inputs, as we discuss in
Section 1.5.5.

For example, suppose we are missing the value of $x_1$. We just have to compute

\begin{align}
p(y = c|\mathbf{x}_{2:D}, \boldsymbol{\theta}) &\propto p(y = c|\boldsymbol{\pi})\, p(\mathbf{x}_{2:D}|y = c, \boldsymbol{\theta}) \tag{9.63} \\
&= p(y = c|\boldsymbol{\pi}) \sum_{x_1} p(x_1, \mathbf{x}_{2:D}|y = c, \boldsymbol{\theta}) \tag{9.64}
\end{align}

In Gaussian discriminant analysis, we can marginalize out $x_1$ using the equations from Section 3.2.3.


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




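If we make the naive Bayes assumption, things are even easier, since we can just ignore the
likelihood term for $x_1$. This follows because

$$\sum_{x_1} p(x_1, \mathbf{x}_{2:D}|y = c, \boldsymbol{\theta}) = \left[ \sum_{x_1} p(x_1|\boldsymbol{\theta}_{1c}) \right] \prod_{d=2}^{D} p(x_d|\boldsymbol{\theta}_{dc}) = \prod_{d=2}^{D} p(x_d|\boldsymbol{\theta}_{dc}) \tag{9.65}$$

where we exploited the fact that $p(x_d|y = c, \boldsymbol{\theta}) = p(x_d|\boldsymbol{\theta}_{dc})$ and $\sum_{x_1} p(x_1|\boldsymbol{\theta}_{1c}) = 1$.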
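
Equation (9.65) says that, under naive Bayes, a missing feature can simply be skipped when computing the class posterior. A minimal sketch for binary features follows, where missing entries are encoded as `np.nan` (an encoding chosen for this example) and the parameters are made-up values.

```python
import numpy as np

def nb_posterior_with_missing(x, pi, theta):
    """Class posterior for a Bernoulli naive Bayes model, skipping NaN features."""
    observed = ~np.isnan(x)
    xo = x[observed]
    th = theta[:, observed]                                   # shape (C, D_obs)
    # only the observed features contribute likelihood terms, as in Eq. (9.65)
    log_lik = (xo * np.log(th) + (1 - xo) * np.log(1 - th)).sum(axis=1)
    logits = np.log(pi) + log_lik
    p = np.exp(logits - logits.max())
    return p / p.sum()

# Hypothetical parameters: 2 classes, 3 binary features.
pi = np.array([0.4, 0.6])
theta = np.array([[0.9, 0.2, 0.7],
                  [0.3, 0.6, 0.5]])
x = np.array([np.nan, 1.0, 0.0])                              # x_1 is missing
post = nb_posterior_with_missing(x, pi, theta)
```

The result matches what one gets by explicitly summing the joint over both possible values of the missing feature, since the Bernoulli probabilities for that feature sum to one in each class.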


9.5 Exercises 


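Exercise 9.1 [Derivation of Fisher’s linear discriminant]
Show that the maximum of $J(\mathbf{w}) = \frac{\mathbf{w}^\top \mathbf{S}_B \mathbf{w}}{\mathbf{w}^\top \mathbf{S}_W \mathbf{w}}$ is given by $\mathbf{S}_B \mathbf{w} = \lambda \mathbf{S}_W \mathbf{w}$,
where $\lambda = \frac{\mathbf{w}^\top \mathbf{S}_B \mathbf{w}}{\mathbf{w}^\top \mathbf{S}_W \mathbf{w}}$. Hint: recall that the derivative of a ratio of two scalars is given by $\frac{d}{dx} \frac{f(x)}{g(x)} = \frac{f'g - fg'}{g^2}$,
where $f' = \frac{d}{dx} f(x)$ and $g' = \frac{d}{dx} g(x)$. Also, recall that $\frac{d}{d\mathbf{x}} \mathbf{x}^\top \mathbf{A} \mathbf{x} = (\mathbf{A} + \mathbf{A}^\top)\mathbf{x}$.


10 Logistic Regression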


10.1 Introduction 


Logistic regression is a widely used discriminative classification model $p(y|\mathbf{x}; \boldsymbol{\theta})$, where $\mathbf{x} \in \mathbb{R}^D$
is a fixed-dimensional input vector, $y \in \{1, \ldots, C\}$ is the class label, and $\boldsymbol{\theta}$ are the parameters. If
$C = 2$, this is known as binary logistic regression, and if $C > 2$, it is known as multinomial
logistic regression, or alternatively, multiclass logistic regression. We give the details below.


10.2 Binary logistic regression 


Binary logistic regression corresponds to the following model

$$p(y|\mathbf{x}, \boldsymbol{\theta}) = \text{Ber}(y|\sigma(\mathbf{w}^\top \mathbf{x} + b)) \tag{10.1}$$

where $\sigma$ is the sigmoid function defined in Section 2.4.2, $\mathbf{w}$ are the weights, $b$ is the bias, and
$\boldsymbol{\theta} = (\mathbf{w}, b)$ are all the parameters. In other words,

$$p(y = 1|\mathbf{x}, \boldsymbol{\theta}) = \sigma(a) = \frac{1}{1 + e^{-a}} \tag{10.2}$$

where $a = \mathbf{w}^\top \mathbf{x} + b = \log(\frac{p}{1 - p})$ is the log-odds, and $p = p(y = 1|\mathbf{x}, \boldsymbol{\theta})$. (In ML, the quantity $a$ is
usually called the logit or the pre-activation.)

Sometimes we choose to use the labels $\tilde{y} \in \{-1, +1\}$ instead of $y \in \{0, 1\}$. We can compute the
probability of these alternative labels using

$$p(\tilde{y}|\mathbf{x}, \boldsymbol{\theta}) = \sigma(\tilde{y} a) \tag{10.3}$$

since $\sigma(-a) = 1 - \sigma(a)$. This slightly more compact notation is widely used in the ML literature.
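
The identities above are easy to verify numerically. The following sketch uses arbitrary illustrative weights and a single input; it checks that the $\{-1, +1\}$ form in Equation (10.3) reproduces both class probabilities and that the log-odds recovers the logit.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

w = np.array([1.5, -0.5])   # illustrative weights
b = 0.25                    # illustrative bias
x = np.array([0.2, 1.0])

a = w @ x + b               # logit (pre-activation)
p1 = sigmoid(a)             # p(y = 1 | x), Eq. (10.2)
p0 = 1.0 - p1               # p(y = 0 | x)

# With labels in {-1, +1}, Eq. (10.3) covers both cases in one expression.
p_plus = sigmoid(+1 * a)
p_minus = sigmoid(-1 * a)

log_odds = np.log(p1 / p0)  # equals the logit a
```

The compact form is convenient when writing the loss, since a single expression $\sigma(\tilde{y} a)$ handles both label values.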


10.2.1 Linear classifiers 


The sigmoid gives the probability that the class label is y = 1. If the loss for misclassifying each class 
is the same, then the optimal decision rule is to predict y = 1 iff class 1 is more likely than class 0, as 
we explained in Section 5.1.2.2. Thus 


f(x) =I (p(y = 1x) > p(y = Ola)) = 1 (tog HT) > 0) =I(a>0) (10.4) 
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Figure 10.1: (a) Visualization of a 2d plane in a 3d space with surface normal w going through point 
Zo = (Xo, yo, Zo). See text for details. (b) Visualization of optimal linear decision boundary induced by logistic 
regression on a 2-class, 2-feature version of the iris dataset. Generated by iris_ logreg.ipynb. Adapted from 
Figure 4.24 of [Gér19]. 


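where $a = \mathbf{w}^\top \mathbf{x} + b$.
Thus we can write the prediction function as follows:

$$f(\mathbf{x}; \boldsymbol{\theta}) = b + \mathbf{w}^\top \mathbf{x} = b + \sum_{d=1}^{D} w_d x_d \tag{10.5}$$

where $\mathbf{w}^\top \mathbf{x} = \langle \mathbf{w}, \mathbf{x} \rangle$ is the inner product between the weight vector $\mathbf{w}$ and the feature vector $\mathbf{x}$.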


This function defines a linear hyperplane, with normal vector $\mathbf{w} \in \mathbb{R}^D$ and an offset $b \in \mathbb{R}$ from
the origin.

Equation (10.5) can be understood by looking at Figure 10.1a. Here we show a plane in a 3d
feature space going through the point $\mathbf{x}_0$ with surface normal $\mathbf{w}$. Points on the surface satisfy
$\mathbf{w}^\top(\mathbf{x} - \mathbf{x}_0) = 0$. If we define $b = -\mathbf{w}^\top \mathbf{x}_0$, we can rewrite this as $\mathbf{w}^\top \mathbf{x} + b = 0$. This plane separates
3d space into two half spaces. This linear plane is known as a decision boundary. If we can
perfectly separate the training examples by such a linear boundary (without making any classification
errors on the training set), we say the data is linearly separable. From Figure 10.1b, we see that
the two-class, two-feature version of the iris dataset is not linearly separable.

In general, there will be uncertainty about the correct class label, so we need to predict a probability
distribution over labels, and not just decide which side of the decision boundary we are on. In
Figure 10.2, we plot $p(y = 1|x_1, x_2; \mathbf{w}) = \sigma(w_1 x_1 + w_2 x_2)$ for different weight vectors $\mathbf{w}$. The vector
$\mathbf{w}$ defines the orientation of the decision boundary, and its magnitude, $\|\mathbf{w}\| = \sqrt{\sum_{d=1}^{D} w_d^2}$, controls
the steepness of the sigmoid, and hence the confidence of the predictions.
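
To make the roles of the direction and the magnitude of $\mathbf{w}$ concrete, here is a minimal NumPy sketch. The function names (`sigmoid`, `predict_proba`, `predict_label`) and the toy inputs are our own illustrative choices, not code from the book's notebooks.

```python
import numpy as np

def sigmoid(a):
    """Logistic function sigma(a) = 1 / (1 + exp(-a))."""
    return 1.0 / (1.0 + np.exp(-a))

def predict_proba(X, w, b=0.0):
    """p(y=1|x) = sigma(w^T x + b) for each row x of X."""
    return sigmoid(X @ w + b)

def predict_label(X, w, b=0.0):
    """Decision rule of Equation (10.4): predict 1 iff the logit a > 0."""
    return (X @ w + b > 0).astype(int)

# Toy inputs (assumed for illustration only).
X = np.array([[1.0, 2.0], [-1.0, 0.5], [0.0, 0.0]])
w = np.array([1.0, -1.0])

p_small = predict_proba(X, w)          # ||w|| = sqrt(2)
p_large = predict_proba(X, 5.0 * w)    # same direction, five times the norm
```

Scaling $\mathbf{w}$ by a positive constant leaves the predicted labels unchanged, but pushes the predicted probabilities away from 0.5, which is the "steepness" effect described above.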


10.2.2 Nonlinear classifiers 


We can often make a problem linearly separable by preprocessing the inputs in a suitable way. In
particular, let $\boldsymbol{\phi}(\mathbf{x})$ be a transformed version of the input feature vector. For example, suppose we
use $\boldsymbol{\phi}(x_1, x_2) = [1, x_1^2, x_2^2]$, and we let $\mathbf{w} = [-R^2, 1, 1]$. Then $\mathbf{w}^\top \boldsymbol{\phi}(\mathbf{x}) = x_1^2 + x_2^2 - R^2$, so the decision
boundary (where $f(\mathbf{x}) = 0$) defines a circle with radius $R$, as shown in Figure 10.3. The resulting
function $f$ is still linear in the parameters $\mathbf{w}$, which is important for simplifying the learning problem,
as we will see in Section 10.2.3. However, we can gain even more power by learning the parameters
of the feature extractor $\boldsymbol{\phi}(\mathbf{x})$ in addition to the linear weights $\mathbf{w}$; we discuss how to do this in Part III.

Figure 10.2: Plots of $\sigma(w_1 x_1 + w_2 x_2)$. Here $\mathbf{w} = (w_1, w_2)$ defines the normal to the decision boundary. Points
to the right of this have $\sigma(\mathbf{w}^\top \mathbf{x}) > 0.5$, and points to the left have $\sigma(\mathbf{w}^\top \mathbf{x}) < 0.5$. Adapted from Figure 39.3
of [Mac03]. Generated by sigmoid_2d_plot.ipynb.

Figure 10.3: Illustration of how we can transform a quadratic decision boundary into a linear one by
transforming the features from $\mathbf{x} = (x_1, x_2)$ to $\boldsymbol{\phi}(\mathbf{x}) = (x_1^2, x_2^2)$. Used with kind permission of Jean-Philippe
Vert.

In Figure 10.3, we used a quadratic expansion of the features. We can also use a higher order
polynomial, as in Section 1.2.2.2. In Figure 10.4, we show the effects of using polynomial expansion
up to degree $K$ on a 2d logistic regression problem. As in Figure 1.7, we see that the model becomes
more complex as the number of parameters increases, and eventually results in overfitting. We discuss
ways to reduce overfitting in Section 10.2.7.
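
The circular decision boundary obtained from the feature map $\boldsymbol{\phi}(x_1, x_2) = [1, x_1^2, x_2^2]$ with $\mathbf{w} = [-R^2, 1, 1]$ can be checked numerically. The following is a small sketch; the radius `R = 2` and the test points are arbitrary values chosen for illustration.

```python
import numpy as np

def phi(X):
    """Feature map phi(x1, x2) = [1, x1^2, x2^2], applied to each row of X."""
    return np.column_stack([np.ones(len(X)), X[:, 0] ** 2, X[:, 1] ** 2])

R = 2.0                                # radius (assumed value)
w = np.array([-R ** 2, 1.0, 1.0])

X = np.array([[0.0, 0.0], [1.0, 1.0], [3.0, 0.0], [2.0, 2.0]])
logits = phi(X) @ w                    # equals x1^2 + x2^2 - R^2
inside = logits < 0                    # True for points strictly inside the circle
```

The model is nonlinear in $\mathbf{x}$ but the logit is still a dot product between $\mathbf{w}$ and the transformed features, so fitting $\mathbf{w}$ remains a linear-in-parameters problem.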

10.2.3 Maximum likelihood estimation

In this section, we discuss how to estimate the parameters of a logistic regression model using 
maximum likelihood estimation. 

10.2.3.1 Objective function 


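The negative log likelihood (scaled by the dataset size $N$) is given by the following (we assume the
bias term $b$ is absorbed into the weight vector $\mathbf{w}$):

$$\mathrm{NLL}(\mathbf{w}) = -\frac{1}{N} \log p(\mathcal{D}|\mathbf{w}) = -\frac{1}{N} \log \prod_{n=1}^{N} \mathrm{Ber}(y_n|\mu_n) \tag{10.6}$$

$$= -\frac{1}{N} \sum_{n=1}^{N} \log\left[\mu_n^{y_n} \times (1 - \mu_n)^{1 - y_n}\right] \tag{10.7}$$

$$= -\frac{1}{N} \sum_{n=1}^{N} \left[y_n \log \mu_n + (1 - y_n) \log(1 - \mu_n)\right] \tag{10.8}$$

$$= \frac{1}{N} \sum_{n=1}^{N} \mathbb{H}_{ce}(y_n, \mu_n) \tag{10.9}$$

where $\mu_n = \sigma(a_n)$ is the probability of class 1, $a_n = \mathbf{w}^\top \mathbf{x}_n$ is the logit, and $\mathbb{H}_{ce}(y_n, \mu_n)$ is the binary
cross entropy defined by

$$\mathbb{H}_{ce}(p, q) = -\left[p \log q + (1 - p) \log(1 - q)\right] \tag{10.10}$$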
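If we use $\tilde{y}_n \in \{-1, +1\}$ instead of $y_n \in \{0, 1\}$, then we can rewrite this as follows:

$$\mathrm{NLL}(\mathbf{w}) = -\frac{1}{N} \sum_{n=1}^{N} \left[\mathbb{I}(\tilde{y}_n = 1) \log(\sigma(a_n)) + \mathbb{I}(\tilde{y}_n = -1) \log(\sigma(-a_n))\right] \tag{10.11}$$

$$= -\frac{1}{N} \sum_{n=1}^{N} \log(\sigma(\tilde{y}_n a_n)) \tag{10.12}$$

$$= \frac{1}{N} \sum_{n=1}^{N} \log(1 + \exp(-\tilde{y}_n a_n)) \tag{10.13}$$

However, in this book, we will mostly use the $y_n \in \{0, 1\}$ notation, since it is easier to generalize to
the multiclass case (Section 10.3), and makes the connection with cross-entropy easier to see.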
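
As a sanity check on the two forms of the objective, the sketch below evaluates Equation (10.8) with $\{0, 1\}$ labels and Equation (10.13) with $\{-1, +1\}$ labels on the same random data. The helper names and the synthetic data are our own assumptions; we use `np.logaddexp` so that $\log(1 + e^{a})$ is computed without overflow.

```python
import numpy as np

def nll_01(w, X, y):
    """Equation (10.8): average binary cross entropy, labels y in {0, 1}."""
    a = X @ w
    # log sigma(a) = -log(1 + exp(-a)); log(1 - sigma(a)) = -log(1 + exp(a))
    log_mu = -np.logaddexp(0.0, -a)
    log_1m_mu = -np.logaddexp(0.0, a)
    return -np.mean(y * log_mu + (1 - y) * log_1m_mu)

def nll_pm1(w, X, y_pm):
    """Equation (10.13): average log(1 + exp(-y a)), labels y in {-1, +1}."""
    return np.mean(np.logaddexp(0.0, -y_pm * (X @ w)))

# Synthetic data (assumed for illustration only).
rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = rng.integers(0, 2, size=20)
w = rng.normal(size=3)

nll_a = nll_01(w, X, y)
nll_b = nll_pm1(w, X, 2 * y - 1)       # map {0, 1} -> {-1, +1}
```

The two values agree up to floating point error, and at $\mathbf{w} = \mathbf{0}$ both reduce to $\log 2$, since every example is then assigned probability 0.5.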


10.2.3.2 Optimizing the objective 
To find the MLE, we must solve

$$\nabla_{\mathbf{w}} \mathrm{NLL}(\mathbf{w}) = \mathbf{g}(\mathbf{w}) = \mathbf{0} \tag{10.14}$$


We can use any gradient-based optimization algorithm to solve this, such as those we discuss in 
Chapter 8. We give a specific example in Section 10.2.4. But first we must derive the gradient, as we 
explain below. 

Figure 10.4: Polynomial feature expansion applied to a two-class, two-dimensional logistic regression problem.
(a) Degree K = 1. (b) Degree K = 2. (c) Degree K = 4. (d) Train and test error vs degree. Generated by
logreg_poly_demo.ipynb.


10.2.3.3 Deriving the gradient 


Although we can use automatic differentiation methods (Section 13.3) to compute the gradient of 
the NLL, it is also easy to do explicitly, as we show below. Fortunately the resulting equations will 
turn out to have a simple and intuitive interpretation, which can be used to derive other methods, as 
we will see. 

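To start, note that

$$\frac{d\mu_n}{da_n} = \sigma(a_n)(1 - \sigma(a_n)) \tag{10.15}$$

where $a_n = \mathbf{w}^\top \mathbf{x}_n$ and $\mu_n = \sigma(a_n)$. Hence by the chain rule (and the rules of vector calculus,
discussed in Section 7.8) we have

$$\frac{\partial}{\partial w_d} \mu_n = \frac{\partial}{\partial w_d} \sigma(\mathbf{w}^\top \mathbf{x}_n) = \frac{\partial}{\partial a_n} \sigma(a_n) \frac{\partial a_n}{\partial w_d} = \mu_n (1 - \mu_n) x_{nd} \tag{10.16}$$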

The gradient for the bias term can be derived in the same way, by using the input $x_{n0} = 1$ in the
above equation. However, we will ignore the bias term for simplicity. Hence

$$\nabla_{\mathbf{w}} \log(\mu_n) = \frac{1}{\mu_n} \nabla_{\mathbf{w}} \mu_n = (1 - \mu_n) \mathbf{x}_n \tag{10.17}$$

Similarly,

$$\nabla_{\mathbf{w}} \log(1 - \mu_n) = \frac{-\mu_n (1 - \mu_n) \mathbf{x}_n}{1 - \mu_n} = -\mu_n \mathbf{x}_n \tag{10.18}$$


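Thus the gradient vector of the NLL is given by

$$\nabla_{\mathbf{w}} \mathrm{NLL}(\mathbf{w}) = -\frac{1}{N} \sum_{n=1}^{N} \left[y_n (1 - \mu_n) \mathbf{x}_n - (1 - y_n) \mu_n \mathbf{x}_n\right] \tag{10.19}$$

$$= -\frac{1}{N} \sum_{n=1}^{N} \left[y_n \mathbf{x}_n - y_n \mathbf{x}_n \mu_n - \mathbf{x}_n \mu_n + y_n \mathbf{x}_n \mu_n\right] \tag{10.20}$$

$$= \frac{1}{N} \sum_{n=1}^{N} (\mu_n - y_n) \mathbf{x}_n \tag{10.21}$$

If we interpret $e_n = \mu_n - y_n$ as an error signal, we can see that the gradient weights each input $\mathbf{x}_n$
by its error, and then averages the result. Note that we can rewrite the gradient in matrix form as
follows:

$$\nabla_{\mathbf{w}} \mathrm{NLL}(\mathbf{w}) = \frac{1}{N} \left(\mathbf{1}_N^\top (\mathrm{diag}(\boldsymbol{\mu} - \mathbf{y}) \mathbf{X})\right)^\top \tag{10.22}$$

where $\mathbf{X}$ is the $N \times D$ design matrix containing the examples $\mathbf{x}_n$ in each row.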
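
The gradient formula in Equation (10.21) is easy to verify numerically. The sketch below compares it against central finite differences on synthetic data; the function names and the data are assumptions made for illustration, and the finite-difference routine is only a sanity check, not part of the method.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def nll(w, X, y):
    """NLL of Equation (10.8), written as mean(log(1 + exp(a)) - y * a)."""
    a = X @ w
    return np.mean(np.logaddexp(0.0, a) - y * a)

def nll_grad(w, X, y):
    """Equation (10.21): (1/N) sum_n (mu_n - y_n) x_n = X^T (mu - y) / N."""
    mu = sigmoid(X @ w)
    return X.T @ (mu - y) / len(y)

def numerical_grad(f, w, eps=1e-6):
    """Central finite differences, used only as a sanity check."""
    g = np.zeros_like(w)
    for d in range(len(w)):
        e = np.zeros_like(w)
        e[d] = eps
        g[d] = (f(w + e) - f(w - e)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 4))
y = rng.integers(0, 2, size=30).astype(float)
w = rng.normal(size=4)

g_exact = nll_grad(w, X, y)
g_fd = numerical_grad(lambda v: nll(v, X, y), w)
```

Note that `X.T @ (mu - y) / N` is the same quantity as the matrix form in Equation (10.22), but avoids building the $N \times N$ diagonal matrix.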


10.2.3.4 Deriving the Hessian 


Gradient-based optimizers will find a stationary point where $\mathbf{g}(\mathbf{w}) = \mathbf{0}$. This could either be a global
optimum or a local optimum. To be sure the stationary point is the global optimum, we must show
that the objective is convex, for reasons we explain in Section 8.1.1.1. Intuitively this means that
the NLL has a bowl shape, with a unique lowest point, which is indeed the case, as illustrated in
Figure 10.5b.

More formally, we must prove that the Hessian is positive semi-definite, which we now do. (See
Chapter 7 for relevant background information on linear algebra.) One can show that the Hessian is
given by

$$\mathbf{H}(\mathbf{w}) = \nabla_{\mathbf{w}} \nabla_{\mathbf{w}}^\top \mathrm{NLL}(\mathbf{w}) = \frac{1}{N} \sum_{n=1}^{N} (\mu_n (1 - \mu_n) \mathbf{x}_n) \mathbf{x}_n^\top = \frac{1}{N} \mathbf{X}^\top \mathbf{S} \mathbf{X} \tag{10.23}$$

where

$$\mathbf{S} \triangleq \mathrm{diag}(\mu_1 (1 - \mu_1), \ldots, \mu_N (1 - \mu_N)) \tag{10.24}$$

Figure 10.5: NLL loss surface for binary logistic regression applied to Iris dataset with 1 feature and 1 bias
term. The goal is to minimize the function. The global MLE is at the center of the plot. Generated by
iris_logreg_loss_surface.ipynb.


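We see that $\mathbf{H}$ is positive definite (provided $\mathbf{X}$ has full column rank), since for any nonzero vector $\mathbf{v}$, we have

$$\mathbf{v}^\top \mathbf{X}^\top \mathbf{S} \mathbf{X} \mathbf{v} = (\mathbf{v}^\top \mathbf{X}^\top \mathbf{S}^{\frac{1}{2}})(\mathbf{S}^{\frac{1}{2}} \mathbf{X} \mathbf{v}) = \|\mathbf{v}^\top \mathbf{X}^\top \mathbf{S}^{\frac{1}{2}}\|_2^2 > 0 \tag{10.25}$$

This follows since $0 < \mu_n < 1$ for all $n$, because of the use of the sigmoid function, so every diagonal entry of $\mathbf{S}$ is strictly positive. Consequently the
NLL is strictly convex. However, in practice, values of $\mu_n$ which are close to 0 or 1 might cause
the Hessian to be close to singular. We can avoid this by using $\ell_2$ regularization, as we discuss in
Section 10.2.7.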
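
The following sketch computes the Hessian of Equation (10.23) on synthetic data and inspects its eigenvalues. The function name `nll_hessian` and the random design matrix are our own assumptions; a Gaussian random $\mathbf{X}$ with $N > D$ has full column rank with probability one, which is the case in which positive definiteness holds.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def nll_hessian(w, X):
    """Equation (10.23): H = (1/N) X^T S X with S = diag(mu_n (1 - mu_n))."""
    mu = sigmoid(X @ w)
    s = mu * (1.0 - mu)
    # Scale the rows of X by s instead of forming the N x N matrix diag(s).
    return (X * s[:, None]).T @ X / X.shape[0]

rng = np.random.default_rng(1)
X = rng.normal(size=(50, 3))
w = rng.normal(size=3)

H = nll_hessian(w, X)
eigvals = np.linalg.eigvalsh(H)        # H is symmetric, so eigvalsh applies
```

All eigenvalues are strictly positive here. If the logits `X @ w` were very large in magnitude, the entries of `s` would approach zero and the smallest eigenvalue would shrink accordingly, which is the near-singularity mentioned above.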


10.2.4 Stochastic gradient descent 


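Our goal is to solve the following optimization problem

$$\hat{\mathbf{w}} \triangleq \operatorname*{argmin}_{\mathbf{w}} \mathcal{L}(\mathbf{w}) \tag{10.26}$$

where $\mathcal{L}(\mathbf{w})$ is the loss function, in this case the negative log likelihood:

$$\mathrm{NLL}(\mathbf{w}) = -\frac{1}{N} \sum_{n=1}^{N} \left[y_n \log \mu_n + (1 - y_n) \log(1 - \mu_n)\right] \tag{10.27}$$

where $\mu_n = \sigma(a_n)$ is the probability of class 1, and $a_n = \mathbf{w}^\top \mathbf{x}_n$ is the log odds.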


There are many algorithms we could use to solve Equation (10.26), as we discuss in Chapter 8.
Perhaps the simplest is to use stochastic gradient descent (Section 8.4). If we use a minibatch of size
1, then we get the following simple update equation:

$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t \nabla_{\mathbf{w}} \mathrm{NLL}(\mathbf{w}_t) = \mathbf{w}_t - \eta_t (\mu_n - y_n) \mathbf{x}_n \tag{10.28}$$

where we replaced the average over all $N$ examples in the gradient of Equation (10.21) with a single
stochastically chosen sample $n$. (The index $n$ changes with $t$.)

Since the objective is convex (see Section 10.2.3.4), one can show that this procedure
will converge to the global optimum, provided we decay the learning rate at the appropriate rate (see
Section 8.4.3). We can improve the convergence speed using variance reduction techniques such as
SAGA (Section 8.4.5.2).
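
A minimal sketch of the size-1 minibatch update in Equation (10.28) is shown below. The decay schedule `eta0 / (1 + 0.01 * t)`, the number of epochs, the use of a fresh random permutation in each epoch, and the synthetic data are all assumptions made for illustration; they are not prescribed by the text.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def nll(w, X, y):
    a = X @ w
    return np.mean(np.logaddexp(0.0, a) - y * a)

def sgd_logreg(X, y, n_epochs=50, eta0=0.5, seed=0):
    """Minibatch-size-1 SGD using the update of Equation (10.28)."""
    rng = np.random.default_rng(seed)
    N, D = X.shape
    w = np.zeros(D)
    t = 0
    for _ in range(n_epochs):
        for n in rng.permutation(N):            # visit examples in random order
            eta = eta0 / (1.0 + 0.01 * t)       # assumed learning rate decay
            w = w - eta * (sigmoid(X[n] @ w) - y[n]) * X[n]
            t += 1
    return w

# Synthetic data sampled from a logistic model (weights are made up).
rng = np.random.default_rng(0)
X = np.column_stack([np.ones(200), rng.normal(size=(200, 2))])
w_true = np.array([-0.5, 2.0, -1.0])
y = (rng.random(200) < sigmoid(X @ w_true)).astype(float)

w_sgd = sgd_logreg(X, y)
```

Each step touches a single example, so the cost per update is $O(D)$ regardless of $N$; the price is a noisy trajectory, which is why the step size must shrink over time.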


10.2.5 Perceptron algorithm 


A perceptron, first introduced in [Ros58], is a deterministic binary classifier of the following form:

$$f(\mathbf{x}_n; \boldsymbol{\theta}) = \mathbb{I}(\mathbf{w}^\top \mathbf{x}_n + b > 0) \tag{10.29}$$

This can be seen to be a limiting case of a binary logistic regression classifier, in which the sigmoid
function $\sigma(a)$ is replaced by the Heaviside step function $H(a) \triangleq \mathbb{I}(a > 0)$. See Figure 2.10 for a
comparison of these two functions.

Since the Heaviside function is not differentiable, we cannot use gradient-based optimization
methods to fit this model. However, Rosenblatt proposed the perceptron learning algorithm
instead. The basic idea is to start with random weights, and then iteratively update them whenever
the model makes a prediction mistake. More precisely, we update the weights using

$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t (\hat{y}_n - y_n) \mathbf{x}_n \tag{10.30}$$

where $(\mathbf{x}_n, y_n)$ is the labeled example sampled at iteration $t$, and $\eta_t$ is the learning rate or step
size. (We can set the step size to 1, since the magnitude of the weights does not affect the decision
boundary.) See perceptron_demo_2d.ipynb for a simple implementation of this algorithm.

The perceptron update rule in Equation (10.30) has an intuitive interpretation: if the prediction is
correct, no change is made, otherwise we move the weights in a direction so as to make the correct
answer more likely. More precisely, if $y_n = 1$ and $\hat{y}_n = 0$, we have $\mathbf{w}_{t+1} = \mathbf{w}_t + \mathbf{x}_n$, and if $y_n = 0$
and $\hat{y}_n = 1$, we have $\mathbf{w}_{t+1} = \mathbf{w}_t - \mathbf{x}_n$.

By comparing Equation (10.30) to Equation (10.28), we see that the perceptron update rule is
equivalent to the SGD update rule for binary logistic regression using the approximation where we
replace the soft probabilities $\mu_n = p(y_n = 1|\mathbf{x}_n)$ with hard labels $\hat{y}_n = f(\mathbf{x}_n)$. The advantage of the
perceptron method is that we don't need to compute probabilities, which can be useful when the
label space is very large. The disadvantage is that the method will only converge when the data is
linearly separable [Nov62], whereas SGD for minimizing the NLL for logistic regression will always
converge to the globally optimal MLE, even if the data is not linearly separable.

In Section 13.2, we will generalize perceptrons to nonlinear functions, thus significantly enhancing 
their usefulness. 
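
The perceptron update of Equation (10.30) can be sketched in a few lines. This is an illustrative implementation under our own assumptions, not the code in perceptron_demo_2d.ipynb: we initialize the weights to zero (rather than randomly) for reproducibility, use step size 1, absorb the bias by prepending a constant feature, and generate toy data that is separable with a margin so that the algorithm is guaranteed to stop.

```python
import numpy as np

def perceptron_fit(X, y, n_epochs=200, seed=0):
    """Perceptron learning with step size 1; labels y in {0, 1}.

    Stops early once a full pass makes no mistakes, which is only
    guaranteed to happen when the data is linearly separable.
    """
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])                    # zero init (an assumption)
    for _ in range(n_epochs):
        mistakes = 0
        for n in rng.permutation(len(y)):
            y_hat = float(X[n] @ w > 0)
            if y_hat != y[n]:
                w = w - (y_hat - y[n]) * X[n]   # Equation (10.30) with eta = 1
                mistakes += 1
        if mistakes == 0:
            break
    return w

# Toy separable data: label is 1 iff x1 + x2 > 0, with a margin around the boundary.
rng = np.random.default_rng(0)
X_raw = rng.uniform(-1, 1, size=(100, 2))
X_raw = X_raw[np.abs(X_raw.sum(axis=1)) > 0.2]
X = np.column_stack([np.ones(len(X_raw)), X_raw])   # prepend 1 for the bias
y = (X_raw.sum(axis=1) > 0).astype(float)

w = perceptron_fit(X, y)
train_acc = np.mean((X @ w > 0) == y)
```

On non-separable data the inner loop would keep making mistakes, and the function would simply return whatever weights it holds after `n_epochs` passes.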


10.2.6 Iteratively reweighted least squares 


Gradient descent is a first order optimization method, which means it only uses first order gradients
to navigate through the loss landscape. This can be slow, especially when some directions of space
point steeply downhill, whereas others have a shallower gradient, as is the case in Figure 10.5a. In
such problems, it can be much faster to use a second order optimization method, which takes the
curvature of the space into account.


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 




We discuss such methods in more detail in Section 8.3. Here we just consider a simple second order 
method that works well for logistic regression. We focus on the full batch setting (so we assume N 
is small), since it is harder to make second order methods work in the stochastic setting (see e.g., 
[Byr+16; Liu+18b] for some methods). 

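The classic second-order method is Newton's method. This consists of updates of the form

$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t \mathbf{H}_t^{-1} \mathbf{g}_t \tag{10.31}$$

where

$$\mathbf{H}_t \triangleq \nabla^2 \mathcal{L}(\mathbf{w})|_{\mathbf{w}_t} = \nabla^2 \mathcal{L}(\mathbf{w}_t) = \mathbf{H}(\mathbf{w}_t) \tag{10.32}$$

is assumed to be positive-definite to ensure the update is well-defined. If the Hessian is exact, we can
set the step size to $\eta_t = 1$.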

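We now apply this method to logistic regression. Recall from Sections 10.2.3.3 and 10.2.3.4 that the gradient
and Hessian are given by

$$\nabla_{\mathbf{w}} \mathrm{NLL}(\mathbf{w}) = \frac{1}{N} \sum_{n=1}^{N} (\mu_n - y_n) \mathbf{x}_n \tag{10.33}$$

$$\mathbf{H} = \frac{1}{N} \mathbf{X}^\top \mathbf{S} \mathbf{X} \tag{10.34}$$

$$\mathbf{S} \triangleq \mathrm{diag}(\mu_1 (1 - \mu_1), \ldots, \mu_N (1 - \mu_N)) \tag{10.35}$$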


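Hence the Newton update has the form

$$\mathbf{w}_{t+1} = \mathbf{w}_t - \mathbf{H}^{-1} \mathbf{g}_t \tag{10.36}$$

$$= \mathbf{w}_t + (\mathbf{X}^\top \mathbf{S}_t \mathbf{X})^{-1} \mathbf{X}^\top (\mathbf{y} - \boldsymbol{\mu}_t) \tag{10.37}$$

$$= (\mathbf{X}^\top \mathbf{S}_t \mathbf{X})^{-1} \left[(\mathbf{X}^\top \mathbf{S}_t \mathbf{X}) \mathbf{w}_t + \mathbf{X}^\top (\mathbf{y} - \boldsymbol{\mu}_t)\right] \tag{10.38}$$

$$= (\mathbf{X}^\top \mathbf{S}_t \mathbf{X})^{-1} \mathbf{X}^\top \left[\mathbf{S}_t \mathbf{X} \mathbf{w}_t + \mathbf{y} - \boldsymbol{\mu}_t\right] \tag{10.39}$$

$$= (\mathbf{X}^\top \mathbf{S}_t \mathbf{X})^{-1} \mathbf{X}^\top \mathbf{S}_t \mathbf{z}_t \tag{10.40}$$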

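where we have defined the working response as

$$\mathbf{z}_t \triangleq \mathbf{X} \mathbf{w}_t + \mathbf{S}_t^{-1} (\mathbf{y} - \boldsymbol{\mu}_t) \tag{10.41}$$

and $\mathbf{S}_t = \mathrm{diag}(\mu_{t,n} (1 - \mu_{t,n}))$. Since $\mathbf{S}_t$ is a diagonal matrix, we can rewrite the targets in component
form as follows:

$$z_{t,n} = \mathbf{w}_t^\top \mathbf{x}_n + \frac{y_n - \mu_{t,n}}{\mu_{t,n} (1 - \mu_{t,n})} \tag{10.42}$$

Equation (10.40) is an example of a weighted least squares problem (Section 11.2.2.4), which is a
minimizer of

$$\sum_{n=1}^{N} S_{t,n} (z_{t,n} - \mathbf{w}_t^\top \mathbf{x}_n)^2 \tag{10.43}$$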
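
Algorithm 10.1: Iteratively reweighted least squares (IRLS)

    $\mathbf{w} = \mathbf{0}$
    repeat
        for $n = 1 : N$ do
            $a_n = \mathbf{w}^\top \mathbf{x}_n$
            $\mu_n = \sigma(a_n)$
            $s_n = \mu_n (1 - \mu_n)$
            $z_n = a_n + \frac{y_n - \mu_n}{s_n}$
        $\mathbf{S} = \mathrm{diag}(s_{1:N})$
        $\mathbf{w} = (\mathbf{X}^\top \mathbf{S} \mathbf{X})^{-1} \mathbf{X}^\top \mathbf{S} \mathbf{z}$
    until converged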


The overall method is therefore known as the iteratively reweighted least squares (IRLS)
algorithm, since at each iteration we solve a weighted least squares problem, where the weight matrix
$\mathbf{S}_t$ changes at each iteration. See Algorithm 10.1 for some pseudocode.
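
The IRLS iteration can be sketched in NumPy as follows. This is an illustrative implementation under our own assumptions (the convergence test on the change in $\mathbf{w}$, the iteration cap, and the synthetic data are not specified in the text). We solve the weighted normal equations with `np.linalg.solve` rather than forming an explicit inverse.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def irls(X, y, max_iter=25, tol=1e-8):
    """IRLS for binary logistic regression (Newton's method as repeated WLS)."""
    w = np.zeros(X.shape[1])
    for _ in range(max_iter):
        a = X @ w
        mu = sigmoid(a)
        s = mu * (1.0 - mu)
        z = a + (y - mu) / s                    # working response, Equation (10.42)
        XtS = X.T * s                           # X^T S without forming diag(s)
        w_new = np.linalg.solve(XtS @ X, XtS @ z)
        if np.max(np.abs(w_new - w)) < tol:     # assumed stopping rule
            return w_new
        w = w_new
    return w

# Synthetic, non-separable data sampled from a logistic model (weights are made up).
rng = np.random.default_rng(0)
X = np.column_stack([np.ones(200), rng.normal(size=(200, 2))])
w_true = np.array([-0.5, 1.5, -1.0])
y = (rng.random(200) < sigmoid(X @ w_true)).astype(float)

w_mle = irls(X, y)
grad_at_mle = X.T @ (sigmoid(X @ w_mle) - y) / len(y)
```

At the returned weights the gradient of the NLL is numerically zero, confirming that a stationary point has been reached. If some `s` values underflow to zero (which can happen when the data is separable and the weights grow without bound), the division in the working response breaks down; regularization, as in Section 10.2.7, is the usual remedy.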

Note that Fisher scoring is the same as IRLS except we replace the Hessian of the actual 
log-likelihood with its expectation, i.e., we use the Fisher information matrix (Section 4.7.2) instead 
of H. Since the Fisher information matrix is independent of the data, it can be precomputed, unlike 
the Hessian, which must be reevaluated at every iteration. This can be faster for problems with 
many parameters. 


10.2.7 MAP estimation 


In Figure 10.4, we saw how logistic regression can overfit when there are too many parameters
compared to training examples. This is a consequence of the ability of maximum likelihood to find
weights that force the decision boundary to "wiggle" in just the right way so as to curve around the
examples. To get this behavior, the weights often need to be set to large values. For example, in
Figure 10.4, when we use degree $K = 1$, we find that the MLE for the two input weights (ignoring
the bias) is

$$\hat{\mathbf{w}} = [0.51291712, 0.11866937] \tag{10.44}$$

When we use degree $K = 2$, we get

$$\hat{\mathbf{w}} = [2.27510513, 0.05970325, 11.84198867, 15.40355969, 2.51242311] \tag{10.45}$$

And when $K = 4$, we get

$$\hat{\mathbf{w}} = [-3.07813766, \cdots, -59.03196044, 51.77152431, 10.25054164] \tag{10.46}$$

One way to reduce such overfitting is to prevent the weights from becoming so large. We can do
this by using a zero-mean Gaussian prior, $p(\mathbf{w}) = \mathcal{N}(\mathbf{w}|\mathbf{0}, C\mathbf{I})$, and then using MAP estimation, as
we discussed in Section 4.5.3. The new training objective becomes

$$\mathcal{L}(\mathbf{w}) = \mathrm{NLL}(\mathbf{w}) + \lambda \|\mathbf{w}\|_2^2 \tag{10.47}$$

Figure 10.6: Weight decay with variance C applied to two-class, two-dimensional logistic regression problem
with a degree 4 polynomial. (a) C = 1. (b) C = 316. (c) C = 100,000. (d) Train and test error vs C.
Generated by logreg_poly_demo.ipynb.


where $\|\mathbf{w}\|_2^2 = \sum_{d=1}^{D} w_d^2$ and $\lambda = 1/C$. This is called $\ell_2$ regularization or weight decay. The
larger the value of $\lambda$, the more the parameters are penalized for being "large" (deviating from the
zero-mean prior), and thus the less flexible the model. See Figure 10.6 for an illustration.

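We can compute the MAP estimate by slightly modifying the input to the above gradient-based
optimization algorithms. The gradient and Hessian of the penalized negative log likelihood have the
following forms:

$$\mathrm{PNLL}(\mathbf{w}) = \mathrm{NLL}(\mathbf{w}) + \lambda \mathbf{w}^\top \mathbf{w} \tag{10.48}$$

$$\nabla_{\mathbf{w}} \mathrm{PNLL}(\mathbf{w}) = \mathbf{g}(\mathbf{w}) + 2\lambda \mathbf{w} \tag{10.49}$$

$$\nabla_{\mathbf{w}}^2 \mathrm{PNLL}(\mathbf{w}) = \mathbf{H}(\mathbf{w}) + 2\lambda \mathbf{I} \tag{10.50}$$

where $\mathbf{g}(\mathbf{w})$ is the gradient and $\mathbf{H}(\mathbf{w})$ is the Hessian of the unpenalized NLL.
For an interesting exercise related to $\ell_2$ regularized logistic regression, see Exercise 10.2.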
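
The sketch below implements the penalized gradient and Hessian of Equations (10.49) and (10.50), and uses them inside a plain Newton loop to compute a MAP estimate. The function names, the fixed number of iterations, the values of `lam`, and the synthetic data are assumptions made for illustration.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def pnll_grad_hess(w, X, y, lam):
    """Gradient and Hessian of PNLL(w) = NLL(w) + lam * w^T w."""
    N = len(y)
    mu = sigmoid(X @ w)
    g = X.T @ (mu - y) / N + 2.0 * lam * w                   # Equation (10.49)
    H = (X.T * (mu * (1.0 - mu))) @ X / N                    # unpenalized Hessian
    H = H + 2.0 * lam * np.eye(len(w))                       # Equation (10.50)
    return g, H

def map_newton(X, y, lam, n_iter=20):
    """Plain Newton iterations (step size 1) on the penalized objective."""
    w = np.zeros(X.shape[1])
    for _ in range(n_iter):
        g, H = pnll_grad_hess(w, X, y, lam)
        w = w - np.linalg.solve(H, g)
    return w

# Synthetic data sampled from a logistic model (weights are made up).
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = (rng.random(100) < sigmoid(X @ np.array([1.0, -2.0, 0.5]))).astype(float)

w_weak = map_newton(X, y, lam=1e-3)    # weak regularization
w_strong = map_newton(X, y, lam=1.0)   # strong regularization
```

The strongly regularized solution has a smaller norm, reflecting the shrinkage towards the zero-mean prior. The added $2\lambda\mathbf{I}$ term also bounds the smallest eigenvalue of the Hessian below by $2\lambda$, which keeps the Newton system well conditioned even when some $\mu_n$ are close to 0 or 1.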

10.2.8 Standardization


In Section 10.2.7, we use an isotropic prior $\mathcal{N}(\mathbf{w}|\mathbf{0}, \lambda^{-1}\mathbf{I})$ to prevent overfitting. This implicitly
encodes the assumption that we expect all weights to be similar in magnitude, which in turn encodes
the assumption we expect all input features to be similar in magnitude. However, in many datasets,
input features are on different scales. In such cases, it is common to standardize the data, to ensure
each feature has mean 0 and variance 1. We can do this by subtracting the mean and dividing by
the standard deviation of each feature, as follows:

$$\mathrm{standardize}(x_{nd}) = \frac{x_{nd} - \hat{\mu}_d}{\hat{\sigma}_d} \tag{10.51}$$

$$\hat{\mu}_d = \frac{1}{N} \sum_{n=1}^{N} x_{nd} \tag{10.52}$$

$$\hat{\sigma}_d^2 = \frac{1}{N} \sum_{n=1}^{N} (x_{nd} - \hat{\mu}_d)^2 \tag{10.53}$$

An alternative is to use min-max scaling, in which we rescale the inputs so they lie in the interval
$[0, 1]$. Both methods ensure the features are comparable in magnitude, which can help with model
fitting and inference, even if we don't use MAP estimation. (See Section 11.7.5 for a discussion of
this point.)
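
Both rescaling schemes take only a few lines of NumPy. In this sketch the function names and the small example matrix are our own; note that `np.std` divides by $N$ by default, matching Equation (10.53), and that a constant feature (zero standard deviation or zero range) would need special handling that we omit here.

```python
import numpy as np

def standardize(X):
    """Equations (10.51)-(10.53): zero mean and unit variance per feature."""
    mu_hat = X.mean(axis=0)
    sigma_hat = X.std(axis=0)          # population std (divides by N)
    return (X - mu_hat) / sigma_hat

def min_max_scale(X):
    """Rescale each feature (column) to lie in [0, 1]."""
    lo, hi = X.min(axis=0), X.max(axis=0)
    return (X - lo) / (hi - lo)

# Two features on very different scales (made-up values).
X = np.array([[1.0, 100.0], [2.0, 300.0], [3.0, 500.0], [6.0, 700.0]])
X_std = standardize(X)
X_mm = min_max_scale(X)
```

In practice the statistics (`mu_hat`, `sigma_hat`, `lo`, `hi`) should be computed on the training set and then reused to transform test data, so that both are mapped in the same way.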


10.3 Multinomial logistic regression 
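
Multinomial logistic regression is a discriminative classification model of the following form:

$$p(y|\mathbf{x}, \boldsymbol{\theta}) = \mathrm{Cat}(y|\mathrm{softmax}(\mathbf{W}\mathbf{x} + \mathbf{b})) \tag{10.54}$$

where $\mathbf{x} \in \mathbb{R}^D$ is the input vector, $y \in \{1, \ldots, C\}$ is the class label, softmax() is the softmax function
(Section 2.5.2), $\mathbf{W}$ is a $C \times D$ weight matrix, $\mathbf{b}$ is a $C$-dimensional bias vector, and $\boldsymbol{\theta} = (\mathbf{W}, \mathbf{b})$ are all the
parameters. (We will henceforth assume we have prepended each $\mathbf{x}$ with a 1, and added $\mathbf{b}$ to the first
column of $\mathbf{W}$, so this simplifies to $\boldsymbol{\theta} = \mathbf{W}$.)

If we let $\mathbf{a} = \mathbf{W}\mathbf{x}$ be the $C$-dimensional vector of logits, then we can rewrite the above as follows:

$$p(y = c|\mathbf{x}, \boldsymbol{\theta}) = \frac{e^{a_c}}{\sum_{c'=1}^{C} e^{a_{c'}}} \tag{10.55}$$

Because of the normalization condition $\sum_{c=1}^{C} p(y_n = c|\mathbf{x}_n, \boldsymbol{\theta}) = 1$, we can set $\mathbf{w}_C = \mathbf{0}$. (For example,
in binary logistic regression, where $C = 2$, we only learn a single weight vector.) Therefore the
parameters $\boldsymbol{\theta}$ correspond to a weight matrix $\mathbf{W}$ of size $(C - 1) \times D$, where $\mathbf{x}_n \in \mathbb{R}^D$.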
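
The claim that one weight vector can be set to zero without changing the model can be checked numerically: subtracting $\mathbf{w}_C$ from every row of $\mathbf{W}$ shifts all logits by the same amount, and softmax is invariant to such shifts. The following sketch uses randomly generated weights and inputs, which are assumptions for illustration only.

```python
import numpy as np

def softmax(a):
    """Softmax over the last axis; subtracting the max avoids overflow."""
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
C, D = 3, 4
W = rng.normal(size=(C, D))            # one row of weights per class
x = rng.normal(size=D)

p_full = softmax(W @ x)
W_clamped = W - W[-1]                  # subtract w_C from every row, so w_C = 0
p_clamped = softmax(W_clamped @ x)
```

The two probability vectors are identical, so only $C - 1$ of the weight vectors are free parameters.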

Note that this model assumes the labels are mutually exclusive, i.e., there is only one true label. For
some applications (e.g., image tagging), we want to predict one or more labels for an input; in this
case, the output space is the set of subsets of $\{1, \ldots, C\}$. This is called multi-label classification,
as opposed to multi-class classification. This can be viewed as a bit vector, $\mathcal{Y} = \{0, 1\}^C$, where
the $c$'th output is set to 1 if the $c$'th tag is present. We can tackle this using a modified version of
binary logistic regression with multiple outputs:

$$p(\mathbf{y}|\mathbf{x}, \boldsymbol{\theta}) = \prod_{c=1}^{C} \mathrm{Ber}(y_c|\sigma(\mathbf{w}_c^\top \mathbf{x})) \tag{10.56}$$

Figure 10.7: Example of 3-class logistic regression with 2d inputs. (a) Original features. (b) Quadratic
features. Generated by logreg_multiclass_demo.ipynb.

10.3.1 Linear and nonlinear classifiers 


Logistic regression computes linear decision boundaries in the input space, as shown in Figure 10.7(a)
for the case where $\mathbf{x} \in \mathbb{R}^2$ and we have $C = 3$ classes. However, we can always transform the inputs
in some way to create nonlinear boundaries. For example, suppose we replace $\mathbf{x} = (x_1, x_2)$ by

$$\boldsymbol{\phi}(\mathbf{x}) = [1, x_1, x_2, x_1^2, x_2^2, x_1 x_2] \tag{10.57}$$

This lets us create quadratic decision boundaries, as illustrated in Figure 10.7(b).


10.3.2 Maximum likelihood estimation 


In this section, we discuss how to compute the maximum likelihood estimate (MLE) by minimizing 
the negative log likelihood (NLL). 


10.3.2.1 Objective 
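The NLL is given by

$$\mathrm{NLL}(\boldsymbol{\theta}) = -\frac{1}{N} \log \prod_{n=1}^{N} \prod_{c=1}^{C} \mu_{nc}^{y_{nc}} = -\frac{1}{N} \sum_{n=1}^{N} \sum_{c=1}^{C} y_{nc} \log \mu_{nc} = \frac{1}{N} \sum_{n=1}^{N} \mathbb{H}_{ce}(\mathbf{y}_n, \boldsymbol{\mu}_n) \tag{10.58}$$

where $\mu_{nc} = p(y_{nc} = 1|\mathbf{x}_n, \boldsymbol{\theta}) = \mathrm{softmax}(f(\mathbf{x}_n; \boldsymbol{\theta}))_c$, $\mathbf{y}_n$ is the one-hot encoding of the label (so
$y_{nc} = \mathbb{I}(y_n = c)$), and $\mathbb{H}_{ce}(\mathbf{y}_n, \boldsymbol{\mu}_n)$ is the cross-entropy:

$$\mathbb{H}_{ce}(\mathbf{p}, \mathbf{q}) = -\sum_{c=1}^{C} p_c \log q_c \tag{10.59}$$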


10.3.2.2 Optimizing the objective 


To find the optimum, we need to solve $\nabla_{\mathbf{w}} \mathrm{NLL}(\mathbf{w}) = \mathbf{0}$, where $\mathbf{w}$ is a vectorized version of the weight
matrix $\mathbf{W}$, and where we are ignoring the bias term for notational simplicity. We can find such a
stationary point using any gradient-based optimizer; we give some examples below. But first we
derive the gradient and Hessian, and then prove that the objective is convex.


10.3.2.3 Deriving the gradient 


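To derive the gradient of the NLL, we need to use the Jacobian of the softmax function, which is as
follows (see Exercise 10.1 for the proof):

$$\frac{\partial \mu_c}{\partial a_j} = \mu_c (\delta_{cj} - \mu_j) \tag{10.60}$$

where $\delta_{cj} = \mathbb{I}(c = j)$. For example, if we have 3 classes, the Jacobian matrix is given by

$$\left[\frac{\partial \mu_c}{\partial a_j}\right]_{cj} = \begin{pmatrix} \mu_1 (1 - \mu_1) & -\mu_1 \mu_2 & -\mu_1 \mu_3 \\ -\mu_2 \mu_1 & \mu_2 (1 - \mu_2) & -\mu_2 \mu_3 \\ -\mu_3 \mu_1 & -\mu_3 \mu_2 & \mu_3 (1 - \mu_3) \end{pmatrix} \tag{10.61}$$

In matrix form, this can be written as

$$\frac{\partial \boldsymbol{\mu}}{\partial \mathbf{a}} = (\boldsymbol{\mu} \mathbf{1}^\top) \odot (\mathbf{I} - \mathbf{1} \boldsymbol{\mu}^\top) \tag{10.62}$$

where $\odot$ is elementwise product, $\boldsymbol{\mu} \mathbf{1}^\top$ copies $\boldsymbol{\mu}$ across each column, and $\mathbf{1} \boldsymbol{\mu}^\top$ copies $\boldsymbol{\mu}$ across each
row.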
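
The matrix form of the softmax Jacobian in Equation (10.62) can be written directly with outer products. The sketch below (function names and the example logits are our own) builds it that way and compares it to the equivalent expression $\mathrm{diag}(\boldsymbol{\mu}) - \boldsymbol{\mu}\boldsymbol{\mu}^\top$.

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

def softmax_jacobian(a):
    """Equation (10.62): J[c, j] = mu_c (delta_cj - mu_j)."""
    mu = softmax(a)
    C = len(mu)
    ones = np.ones(C)
    # (mu 1^T) has mu in every column; (1 mu^T) has mu in every row.
    return np.outer(mu, ones) * (np.eye(C) - np.outer(ones, mu))

a = np.array([0.5, -1.0, 2.0])         # example logits (made up)
mu = softmax(a)
J = softmax_jacobian(a)
```

Each column of the Jacobian sums to zero, which reflects the fact that the softmax outputs always sum to one, so any change in one probability must be offset by changes in the others.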

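We now derive the gradient of the NLL for a single example, indexed by $n$. To do this, we flatten
the $D \times C$ weight matrix into a vector $\mathbf{w}$ of size $CD$ (or $(C - 1)D$ if we freeze one of the classes to
have zero weight) by concatenating the rows, and then transposing into a column vector. We use $\mathbf{w}_j$
to denote the vector of weights associated with class $j$. The gradient wrt this vector is given by the
following (where we use the Kronecker delta notation, $\delta_{jc}$, which equals 1 if $j = c$ and 0 otherwise):

$$\nabla_{\mathbf{w}_j} \mathrm{NLL}_n = \sum_{c} \frac{\partial \mathrm{NLL}_n}{\partial \mu_{nc}} \frac{\partial \mu_{nc}}{\partial a_{nj}} \frac{\partial a_{nj}}{\partial \mathbf{w}_j} \tag{10.63}$$

$$= -\sum_{c} \frac{y_{nc}}{\mu_{nc}} \mu_{nc} (\delta_{jc} - \mu_{nj}) \mathbf{x}_n \tag{10.64}$$

$$= \sum_{c} y_{nc} (\mu_{nj} - \delta_{jc}) \mathbf{x}_n \tag{10.65}$$

$$= \left(\sum_{c} y_{nc}\right) \mu_{nj} \mathbf{x}_n - \sum_{c} \delta_{jc} y_{nj} \mathbf{x}_n \tag{10.66}$$

$$= (\mu_{nj} - y_{nj}) \mathbf{x}_n \tag{10.67}$$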


We can repeat this computation for each class, to get the full gradient vector. The gradient of the
overall NLL is obtained by averaging over examples, to give the $D \times C$ matrix

$$\mathbf{g}(\mathbf{w}) = \frac{1}{N} \sum_{n=1}^{N} \mathbf{x}_n (\boldsymbol{\mu}_n - \mathbf{y}_n)^\top \tag{10.68}$$

This has the same form as in the binary logistic regression case, namely an error term times the
input.
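
The multiclass gradient can be verified against finite differences in the same way as the binary case. In the sketch below, the function names and the synthetic data are assumptions; we store the gradient as a $C \times D$ array to match the shape of $\mathbf{W}$, which is the transpose of the $D \times C$ layout used in Equation (10.68).

```python
import numpy as np

def softmax(A):
    A = A - A.max(axis=1, keepdims=True)
    E = np.exp(A)
    return E / E.sum(axis=1, keepdims=True)

def nll(W, X, Y):
    """Equation (10.58); W is C x D, X is N x D, Y is N x C one-hot."""
    return -np.mean(np.sum(Y * np.log(softmax(X @ W.T)), axis=1))

def nll_grad(W, X, Y):
    """Gradient wrt W as a C x D array: (1/N) sum_n (mu_n - y_n) x_n^T."""
    return (softmax(X @ W.T) - Y).T @ X / X.shape[0]

rng = np.random.default_rng(0)
N, D, C = 20, 3, 4
X = rng.normal(size=(N, D))
Y = np.eye(C)[rng.integers(0, C, size=N)]   # one-hot labels
W = rng.normal(size=(C, D))

G = nll_grad(W, X, Y)

# Finite-difference check of a single entry (a sanity check only).
eps = 1e-6
E = np.zeros_like(W)
E[1, 2] = eps
fd = (nll(W + E, X, Y) - nll(W - E, X, Y)) / (2 * eps)
```

Summing the gradient over classes gives zero for every feature, because both $\boldsymbol{\mu}_n$ and $\mathbf{y}_n$ sum to one; this is the same redundancy that allows one class's weights to be clamped to zero.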
10.3.2.4 Deriving the Hessian 


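Exercise 10.1 asks you to show that the Hessian of the NLL for multinomial logistic regression is
given by

$$\mathbf{H}(\mathbf{w}) = \frac{1}{N} \sum_{n=1}^{N} (\mathrm{diag}(\boldsymbol{\mu}_n) - \boldsymbol{\mu}_n \boldsymbol{\mu}_n^\top) \otimes (\mathbf{x}_n \mathbf{x}_n^\top) \tag{10.69}$$

where $\mathbf{A} \otimes \mathbf{B}$ is the Kronecker product (Section 7.2.5). In other words, the block $c, c'$ submatrix is
given by

$$\mathbf{H}_{c,c'}(\mathbf{w}) = \frac{1}{N} \sum_{n} \mu_{nc} (\delta_{c,c'} - \mu_{n,c'}) \mathbf{x}_n \mathbf{x}_n^\top \tag{10.70}$$

For example, if we have 3 features and 2 classes, this becomes

$$\mathbf{H}(\mathbf{w}) = \frac{1}{N} \sum_{n} \begin{pmatrix} \mu_{n1} - \mu_{n1}^2 & -\mu_{n1} \mu_{n2} \\ -\mu_{n1} \mu_{n2} & \mu_{n2} - \mu_{n2}^2 \end{pmatrix} \otimes \begin{pmatrix} x_{n1} x_{n1} & x_{n1} x_{n2} & x_{n1} x_{n3} \\ x_{n2} x_{n1} & x_{n2} x_{n2} & x_{n2} x_{n3} \\ x_{n3} x_{n1} & x_{n3} x_{n2} & x_{n3} x_{n3} \end{pmatrix} \tag{10.71}$$

$$= \frac{1}{N} \sum_{n} \begin{pmatrix} (\mu_{n1} - \mu_{n1}^2) \mathbf{X}_n & -\mu_{n1} \mu_{n2} \mathbf{X}_n \\ -\mu_{n1} \mu_{n2} \mathbf{X}_n & (\mu_{n2} - \mu_{n2}^2) \mathbf{X}_n \end{pmatrix} \tag{10.72}$$

where $\mathbf{X}_n = \mathbf{x}_n \mathbf{x}_n^\top$. Exercise 10.1 also asks you to show that this is a positive definite matrix, so the
objective is convex.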

10.3.3 Gradient-based optimization


It is straightforward to use the gradient in Section 10.3.2.3 to derive the SGD algorithm. Similarly,
we can use the Hessian in Section 10.3.2.4 to derive a second-order optimization method. However,
computing the Hessian can be expensive, so it is common to approximate it using quasi-Newton
methods, such as limited memory BFGS. (BFGS stands for Broyden, Fletcher, Goldfarb and Shanno.)
See Section 8.3.2 for details. Another approach, which is similar to IRLS, is described in Section 10.3.4.
All of these methods rely on computing the gradient of the log-likelihood, which in turn requires
computing normalized probabilities, which can be computed from the logits vector $\mathbf{a} = \mathbf{W}\mathbf{x}$ using

$$p(y = c|\mathbf{x}) = \exp(a_c - \mathrm{lse}(\mathbf{a})) \tag{10.73}$$

where lse is the log-sum-exp function defined in Section 2.5.4. For this reason, many software libraries
define a version of the cross-entropy loss that takes unnormalized logits as input.
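
To illustrate why working with logits is numerically safer, the sketch below implements a log-sum-exp with the usual max-subtraction trick and uses it to evaluate Equation (10.73) on logits that would overflow a naive `exp`. The helper names and the example logits are our own assumptions.

```python
import numpy as np

def logsumexp(a):
    """Stable log(sum(exp(a))) using the max-subtraction trick."""
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))

def cross_entropy_from_logits(a, c):
    """-log p(y = c | x), computed as lse(a) - a_c without forming probabilities."""
    return logsumexp(a) - a[c]

a = np.array([1000.0, 1001.0, 999.0])  # naive exp(a) would overflow to inf
p = np.exp(a - logsumexp(a))           # Equation (10.73)
loss = cross_entropy_from_logits(a, 1)
```

The loss is obtained by a subtraction in log space, so it stays finite even when the probability of the true class is extremely small.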


10.3.4 Bound optimization 


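In this section, we consider an approach for fitting logistic regression using a class of algorithms
known as bound optimization, which we describe in Section 8.7. The basic idea is to iteratively
construct a lower bound on the function you want to maximize, and then to update the bound, so it
"pushes up" on the true function. Optimizing the bound is often easier than optimizing the function
directly.

If $\ell(\boldsymbol{\theta})$ is a concave function we want to maximize, then one way to obtain a valid lower bound is
to use a bound on its Hessian, i.e., to find a negative definite matrix $\mathbf{B}$ such that $\mathbf{H}(\boldsymbol{\theta}) \succ \mathbf{B}$. In this
case, one can show that

$$\ell(\boldsymbol{\theta}) \geq \ell(\boldsymbol{\theta}^t) + (\boldsymbol{\theta} - \boldsymbol{\theta}^t)^\top \mathbf{g}(\boldsymbol{\theta}^t) + \frac{1}{2} (\boldsymbol{\theta} - \boldsymbol{\theta}^t)^\top \mathbf{B} (\boldsymbol{\theta} - \boldsymbol{\theta}^t) \tag{10.74}$$

where $\mathbf{g}(\boldsymbol{\theta}^t) = \nabla \ell(\boldsymbol{\theta}^t)$. Defining $Q(\boldsymbol{\theta}, \boldsymbol{\theta}^t)$ as the right-hand-side of Equation (10.74), the update
becomes

$$\boldsymbol{\theta}^{t+1} = \boldsymbol{\theta}^t - \mathbf{B}^{-1} \mathbf{g}(\boldsymbol{\theta}^t) \tag{10.75}$$

This is similar to a Newton update, except we use $\mathbf{B}$, which is a fixed matrix, rather than $\mathbf{H}(\boldsymbol{\theta}^t)$,
which changes at each iteration. This can give us some of the advantages of second order methods at
lower computational cost.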

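Let us now apply this to logistic regression, following [Kri+05]. Let $\boldsymbol{\mu}_n(\mathbf{w}) = [p(y_n = 1|\mathbf{x}_n, \mathbf{w}), \ldots, p(y_n = C|\mathbf{x}_n, \mathbf{w})]$
and $\mathbf{y}_n = [\mathbb{I}(y_n = 1), \ldots, \mathbb{I}(y_n = C)]$. We want to maximize the log-likelihood, which is
as follows:

$$\ell(\mathbf{w}) = \sum_{n=1}^{N} \left[\sum_{c=1}^{C} y_{nc} \mathbf{w}_c^\top \mathbf{x}_n - \log \sum_{c=1}^{C} \exp(\mathbf{w}_c^\top \mathbf{x}_n)\right] \tag{10.76}$$

The gradient is given by the following (see Section 10.3.2.3 for details of the derivation):

$$\mathbf{g}(\mathbf{w}) = \sum_{n=1}^{N} (\mathbf{y}_n - \boldsymbol{\mu}_n(\mathbf{w})) \otimes \mathbf{x}_n \tag{10.77}$$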
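
where $\otimes$ denotes Kronecker product (which, in this case, is just outer product of the two vectors).
The Hessian is given by the following (see Section 10.3.2.4 for details of the derivation):

$$\mathbf{H}(\mathbf{w}) = -\sum_{n=1}^{N} (\mathrm{diag}(\boldsymbol{\mu}_n(\mathbf{w})) - \boldsymbol{\mu}_n(\mathbf{w}) \boldsymbol{\mu}_n(\mathbf{w})^\top) \otimes (\mathbf{x}_n \mathbf{x}_n^\top) \tag{10.78}$$

We can construct a lower bound on the Hessian, as shown in [Boh92]:

$$\mathbf{H}(\mathbf{w}) \succ -\frac{1}{2} \left[\mathbf{I} - \mathbf{1}\mathbf{1}^\top / C\right] \otimes \left(\sum_{n=1}^{N} \mathbf{x}_n \mathbf{x}_n^\top\right) \triangleq \mathbf{B} \tag{10.79}$$

where $\mathbf{I}$ is a $C$-dimensional identity matrix, and $\mathbf{1}$ is a $C$-dimensional vector of all 1s.¹ In the binary
case, this becomes

$$\mathbf{H}(\mathbf{w}) \succ -\frac{1}{2} \left(1 - \frac{1}{2}\right) \left(\sum_{n=1}^{N} \mathbf{x}_n \mathbf{x}_n^\top\right) = -\frac{1}{4} \mathbf{X}^\top \mathbf{X} \tag{10.80}$$

This follows since $\mu_n (1 - \mu_n) \leq 0.25$ for any $\mu_n \in [0, 1]$, so $-(\mu_n - \mu_n^2) \geq -0.25$.
We can use this lower bound to construct an MM algorithm to find the MLE. The update becomes

$$\mathbf{w}^{t+1} = \mathbf{w}^t - \mathbf{B}^{-1} \mathbf{g}(\mathbf{w}^t) \tag{10.81}$$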


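This iteration can be faster than IRLS (Section 10.2.6) since we can precompute $\mathbf{B}^{-1}$ in time
independent of $N$, rather than having to invert the Hessian at each iteration. For example, let us
consider the binary case, so $\mathbf{g}^t = \nabla \ell(\mathbf{w}^t) = \mathbf{X}^\top (\mathbf{y} - \boldsymbol{\mu}^t)$, where $\boldsymbol{\mu}^t = [p_n(\mathbf{w}^t), (1 - p_n(\mathbf{w}^t))]_{n=1}^{N}$. Since
$\mathbf{B} = -\frac{1}{4} \mathbf{X}^\top \mathbf{X}$, we have $\mathbf{B}^{-1} = -4 (\mathbf{X}^\top \mathbf{X})^{-1}$, so the
update becomes

$$\mathbf{w}^{t+1} = \mathbf{w}^t + 4 (\mathbf{X}^\top \mathbf{X})^{-1} \mathbf{g}^t \tag{10.82}$$

Compare this to Equation (10.37), which has the following form:

$$\mathbf{w}^{t+1} = \mathbf{w}^t - \mathbf{H}^{-1} \mathbf{g}(\mathbf{w}^t) = \mathbf{w}^t + (\mathbf{X}^\top \mathbf{S}^t \mathbf{X})^{-1} \mathbf{g}^t \tag{10.83}$$

where $\mathbf{S}^t = \mathrm{diag}(\boldsymbol{\mu}^t \odot (1 - \boldsymbol{\mu}^t))$. (Here $\mathbf{H} = -\mathbf{X}^\top \mathbf{S}^t \mathbf{X}$ is the Hessian of the log-likelihood, which is negative definite.) We see that Equation (10.82) is faster to compute, since we can
precompute the constant matrix $(\mathbf{X}^\top \mathbf{X})^{-1}$.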
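
The binary bound optimization update can be sketched as follows. We write the step as $\mathbf{w} \leftarrow \mathbf{w} + 4(\mathbf{X}^\top\mathbf{X})^{-1}\mathbf{X}^\top(\mathbf{y} - \boldsymbol{\mu})$, i.e., $\mathbf{w} - \mathbf{B}^{-1}\mathbf{g}$ with $\mathbf{B} = -\frac{1}{4}\mathbf{X}^\top\mathbf{X}$ and $\mathbf{g}$ the gradient of the log-likelihood. The function names, the fixed iteration count, and the synthetic data are assumptions made for illustration.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def loglik(w, X, y):
    """Binary log-likelihood sum_n [y_n a_n - log(1 + exp(a_n))]."""
    a = X @ w
    return np.sum(y * a - np.logaddexp(0.0, a))

def bound_opt_logreg(X, y, n_iter=200):
    """MM iterations using the fixed curvature bound B = -X^T X / 4."""
    neg_B_inv = 4.0 * np.linalg.inv(X.T @ X)    # computed once, reused every iteration
    w = np.zeros(X.shape[1])
    history = [loglik(w, X, y)]
    for _ in range(n_iter):
        g = X.T @ (y - sigmoid(X @ w))          # gradient of the log-likelihood
        w = w + neg_B_inv @ g                   # w - B^{-1} g
        history.append(loglik(w, X, y))
    return w, np.array(history)

# Synthetic, non-separable data sampled from a logistic model (weights are made up).
rng = np.random.default_rng(0)
X = np.column_stack([np.ones(200), rng.normal(size=(200, 2))])
y = (rng.random(200) < sigmoid(X @ np.array([-0.5, 1.5, -1.0]))).astype(float)

w_mm, history = bound_opt_logreg(X, y)
```

Because each step maximizes a lower bound that touches the log-likelihood at the current iterate, the values in `history` never decrease. Each iteration is cheaper than an IRLS step, but more iterations are typically needed, since the fixed bound is more conservative than the true curvature.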


10.3.5 MAP estimation 


In Section 10.2.7 we discussed the benefits of $\ell_2$ regularization for binary logistic regression. These
benefits hold also in the multi-class case. However, there is also an additional, and surprising, benefit
to do with identifiability of the parameters, as pointed out in [HTF09, Ex.18.3]. (We say that the
parameters are identifiable if there is a unique value that maximizes the likelihood; equivalently, we
require that the NLL be strictly convex.)


1. If we enforce that $\mathbf{w}_C = \mathbf{0}$, we can use $C - 1$ dimensions for these vectors / matrices.

To see why identifiability is an issue, recall that multiclass logistic regression has the form

$$p(y = c|\mathbf{x}, \mathbf{W}) = \frac{\exp(\mathbf{w}_c^\top \mathbf{x})}{\sum_{k=1}^{C} \exp(\mathbf{w}_k^\top \mathbf{x})} \tag{10.84}$$

where $\mathbf{W}$ is a $C \times D$ weight matrix. We can arbitrarily define $\mathbf{w}_c = \mathbf{0}$ for one of the classes, say
$c = C$, since $p(y = C|\mathbf{x}, \mathbf{W}) = 1 - \sum_{c=1}^{C-1} p(y = c|\mathbf{x}, \mathbf{W})$. In this case, the model has the form

$$p(y = c|\mathbf{x}, \mathbf{W}) = \frac{\exp(\mathbf{w}_c^\top \mathbf{x})}{1 + \sum_{k=1}^{C-1} \exp(\mathbf{w}_k^\top \mathbf{x})} \tag{10.85}$$

If we don't "clamp" one of the vectors to some constant value, the parameters will be unidentifiable.
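However, suppose we don't clamp $\mathbf{w}_c = \mathbf{0}$, so we are using Equation (10.84), but we add $\ell_2$
regularization by optimizing

$$\mathrm{PNLL}(\mathbf{W}) = -\sum_{n=1}^{N} \log p(y_n|\mathbf{x}_n, \mathbf{W}) + \lambda \sum_{c=1}^{C} \|\mathbf{w}_c\|_2^2 \tag{10.86}$$

where we have absorbed the $1/N$ term into $\lambda$. At the optimum we have $\sum_{c=1}^{C} \hat{w}_{cj} = 0$ for $j = 1 : D$,
so the weights automatically satisfy a sum-to-zero constraint, thus making them uniquely identifiable.
To see why, note that at the optimum we have

$$\nabla \mathrm{NLL}(\mathbf{w}) + 2\lambda \mathbf{w} = \mathbf{0} \tag{10.87}$$

$$\sum_{n} (\mathbf{y}_n - \boldsymbol{\mu}_n) \otimes \mathbf{x}_n = 2\lambda \mathbf{w} \tag{10.88}$$

Hence for any feature dimension $j$ we have

$$2\lambda \sum_{c} w_{cj} = \sum_{n} \sum_{c} (y_{nc} - \mu_{nc}) x_{nj} = \sum_{n} \left(\sum_{c} y_{nc} - \sum_{c} \mu_{nc}\right) x_{nj} = \sum_{n} (1 - 1) x_{nj} = 0 \tag{10.89}$$

Thus if $\lambda > 0$ we have $\sum_{c} \hat{w}_{cj} = 0$, so the weights will sum to zero across classes for each feature
dimension.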


10.3.6 Maximum entropy classifiers 


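Recall that the multinomial logistic regression model can be written as

$$p(y = c|\mathbf{x}, \mathbf{W}) = \frac{\exp(\mathbf{w}_c^\top \mathbf{x})}{Z(\mathbf{w}, \mathbf{x})} = \frac{\exp(\mathbf{w}_c^\top \mathbf{x})}{\sum_{c'=1}^{C} \exp(\mathbf{w}_{c'}^\top \mathbf{x})} \tag{10.90}$$

where $Z(\mathbf{w}, \mathbf{x}) = \sum_{c} \exp(\mathbf{w}_c^\top \mathbf{x})$ is the partition function (normalization constant). This uses the
same features, but a different weight vector, for every class. There is a slight extension of this model
that allows us to use features that are class-dependent. This model can be written as

$$p(y = c|\mathbf{x}, \mathbf{w}) = \frac{1}{Z(\mathbf{w}, \mathbf{x})} \exp(\mathbf{w}^\top \boldsymbol{\phi}(\mathbf{x}, c)) \tag{10.91}$$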

Figure 10.8: A simple example of a label hierarchy. Nodes within the same ellipse have a mutual exclusion
relationship between them.


where $\boldsymbol{\phi}(\mathbf{x}, c)$ is the feature vector for class $c$. This is called a maximum entropy classifier, or
maxent classifier for short. (The origin of this term is explained in Section 3.4.4.)

Maxent classifiers include multinomial logistic regression as a special case. To see this let $\mathbf{w} =
[\mathbf{w}_1, \ldots, \mathbf{w}_C]$, and define the feature vector as follows:

$$\boldsymbol{\phi}(\mathbf{x}, c) = [\mathbf{0}, \ldots, \mathbf{x}, \ldots, \mathbf{0}] \tag{10.92}$$

where $\mathbf{x}$ is embedded in the $c$'th block, and the remaining blocks are zero. In this case, $\mathbf{w}^\top \boldsymbol{\phi}(\mathbf{x}, c) =
\mathbf{w}_c^\top \mathbf{x}$, so we recover multinomial logistic regression.
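
The block embedding of Equation (10.92) can be checked directly: with the concatenated weight vector $\mathbf{w} = [\mathbf{w}_1, \ldots, \mathbf{w}_C]$, the maxent logit $\mathbf{w}^\top \boldsymbol{\phi}(\mathbf{x}, c)$ equals the multinomial logistic regression logit $\mathbf{w}_c^\top \mathbf{x}$. In this sketch the function name `phi` and the random weights are our own, and classes are indexed from 0 to follow Python conventions.

```python
import numpy as np

def phi(x, c, C):
    """Equation (10.92): place x in block c (0-indexed) of C blocks, zeros elsewhere."""
    D = len(x)
    out = np.zeros(C * D)
    out[c * D:(c + 1) * D] = x
    return out

rng = np.random.default_rng(0)
C, D = 3, 4
W = rng.normal(size=(C, D))            # rows are w_1, ..., w_C
w = W.reshape(-1)                      # concatenation [w_1, ..., w_C]
x = rng.normal(size=D)

logits_maxent = np.array([w @ phi(x, c, C) for c in range(C)])
logits_mlr = W @ x
```

The two logit vectors coincide, so applying softmax to either one yields the same class probabilities.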

Maxent classifiers are very widely used in the field of natural language processing. For example,
consider the problem of semantic role labeling, where we classify a word $x$ into a semantic role $y$,
such as person, place or thing. We might define (binary) features such as the following:

$$\phi_1(x, y) = \mathbb{I}(y = \text{person} \wedge x \text{ occurs after "Mr." or "Mrs"}) \tag{10.93}$$

$$\phi_2(x, y) = \mathbb{I}(y = \text{person} \wedge x \text{ is in whitelist of common names}) \tag{10.94}$$

$$\phi_3(x, y) = \mathbb{I}(y = \text{place} \wedge x \text{ is in Google maps}) \tag{10.95}$$

We see that the features we use depend on the label.

There are two main ways of creating these features. The first is to manually specify many possibly 
useful features using various templates, and then use a feature selection algorithm, such as the group 
lasso method of Section 11.4.7. The second is to incrementally add features to the model, using a 
heuristic feature generation method. 


10.3.7 Hierarchical classification 


Sometimes the set of possible labels can be structured into a hierarchy or taxonomy. For example,
we might want to predict what kind of an animal is in an image: it could be a dog or a cat; if it is a
dog, it could be a golden retriever or a German shepherd, etc. Intuitively, it makes sense to try to
predict the most precise label for which we are confident [Den+12], that is, the system should "hedge
its bets".

One simple way to achieve this, proposed in [RF17], is as follows. First, create a model with
a binary output label for every possible node in the tree. Before training the model, we will use
label smearing, so that a label is propagated to all of its parents (hypernyms). For example, if
an image is labeled "golden retriever", we will also label it "dog". If we train a multi-label classifier
(which produces a vector $p(\mathbf{y}|\mathbf{x})$ of binary labels) on such smeared data, it will perform hierarchical
classification, predicting a set of labels at different levels of abstraction.

However, this method could predict “golden retriever”, “cat” and “bird” all with probability 1.0, since the model does not capture the fact that some labels are mutually exclusive. To prevent this, we can add a mutual exclusion constraint between all label nodes which are siblings, as shown in Figure 10.8. For example, this model enforces that $p(\text{mammal}|\mathbf{x}) + p(\text{bird}|\mathbf{x}) = 1$, since these two labels are children of the root node. We can further partition the mammal probability into dogs and cats, so we have $p(\text{dog}|\mathbf{x}) + p(\text{cat}|\mathbf{x}) = p(\text{mammal}|\mathbf{x})$.

[Den+14; Din+15] generalize the above method by using a conditional graphical model where 
the graph structure can be more complex than a tree. In addition, they allow for soft constraints 
between labels, in addition to hard constraints. 


10.3.8 Handling large numbers of classes 


In this section, we discuss some issues that arise when there are a large number of potential labels, 
e.g., if the labels correspond to words from a language. 


10.3.8.1 Hierarchical softmax 


In regular softmax classifiers, computing the normalization constant, which is needed to compute 
the gradient of the log likelihood, takes O(C) time, which can become the bottleneck if C is large. 
However, if we structure the labels as a tree, we can compute the probability of any label in O(log C) 
time, by multiplying the probabilities of each edge on the path from the root to the leaf. For example, 
consider the tree in Figure 10.9. We have 


$$p(y = \text{I'm} \mid C) = 0.57 \times 0.68 \times 0.72 = 0.28 \tag{10.96}$$


Thus we replace the “flat” output softmax with a tree-structured sequence of binary classifiers. This 
is called hierarchical softmax [Goo01; MB05]. 
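The path-product computation can be sketched as follows. This is only an illustration, assuming a complete binary tree stored in heap order with one logistic classifier per internal node; the function name `leaf_prob`, the random weights, and the tree layout are assumptions made for the example, not the construction used in [Goo01; MB05].

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def leaf_prob(x, node_weights, leaf, depth):
    """p(leaf | x): product of binary branch probabilities from root to leaf."""
    node, prob = 0, 1.0                       # start at the root
    for level in reversed(range(depth)):
        go_right = (leaf >> level) & 1        # next bit of the leaf index
        p_right = sigmoid(node_weights[node] @ x)
        prob *= p_right if go_right else 1.0 - p_right
        node = 2 * node + 1 + go_right        # heap-order child index
    return prob

rng = np.random.default_rng(0)
depth, D = 3, 4                               # 8 leaves, 4 input features (assumptions)
num_leaves = 2 ** depth
node_weights = rng.normal(size=(num_leaves - 1, D))   # one weight vector per internal node
x = rng.normal(size=D)

leaf_probs = np.array([leaf_prob(x, node_weights, c, depth) for c in range(num_leaves)])

# The worked example of Equation (10.96)
path_example = 0.57 * 0.68 * 0.72
```

Each call touches only `depth` = log2(C) nodes, yet the leaf probabilities still sum to one because the two branches at every internal node sum to one.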

A good way to structure such a tree is to use Huffman encoding, where the most frequent labels are placed near the top of the tree, as suggested in [Mik+13a]. (For a different approach, based on clustering the most common labels together, see [Gra+17]. And for yet another approach, based on sampling labels, see [Tit16].) 


10.3.8.2 Class imbalance and the long tail 


Another issue that often arises when there are a large number of classes is that for most classes, we may have very few examples. More precisely, if $N_c$ is the number of examples of class $c$, then the empirical distribution $p(N_1, \ldots, N_C)$ may have a long tail. The result is an extreme form of class imbalance (see e.g., [ASR15]). Since the rare classes will have a smaller effect on the overall loss than the common classes, the model may “focus its attention” on the common classes.

Figure 10.9: A flat and hierarchical softmax model $p(w|C)$, where $C$ are the input features (context) and $w$ is the output label (word). Adapted from https://www.quora.com/What-is-hierarchical-softmax.

One method that can help is to set the bias terms $\mathbf{b}$ such that $\mathrm{softmax}(\mathbf{b})_c = N_c/N$; such a model will match the empirical label prior even when using weights of $\mathbf{w} = \mathbf{0}$. We can then “subtract off” the prior term by using logit adjustment [Men+21], which ensures good performance across all groups.

Another common approach is to resample the data to make it more balanced, before (or during) training. In particular, suppose we sample a datapoint from class $c$ with probability

$$p_c = \frac{N_c^q}{\sum_{i=1}^C N_i^q} \tag{10.97}$$

If we set $q = 1$, we recover standard instance-balanced sampling, where $p_c \propto N_c$; the common classes will be sampled more than rare classes. If we set $q = 0$, we recover class-balanced sampling, where $p_c = 1/C$; this can be thought of as first sampling a class uniformly at random, and then sampling an instance of this class. Finally, we can consider other options, such as $q = 0.5$, which is known as square-root sampling [Mah+18].
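To make the effect of $q$ concrete, here is a small sketch of Equation (10.97). The class counts are made-up numbers chosen for illustration, and the function name `class_sampling_probs` is hypothetical.

```python
import numpy as np

def class_sampling_probs(counts, q):
    """p_c proportional to N_c^q, as in Equation (10.97)."""
    counts = np.asarray(counts, dtype=float)
    weights = counts ** q
    return weights / weights.sum()

counts = [900, 90, 10]                              # illustrative long-tailed counts (assumption)
p_instance = class_sampling_probs(counts, q=1.0)    # instance-balanced
p_class = class_sampling_probs(counts, q=0.0)       # class-balanced
p_sqrt = class_sampling_probs(counts, q=0.5)        # square-root sampling
```

With these counts, square-root sampling gives the rarest class a probability that lies between its instance-balanced and class-balanced values.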

Yet another method that is simple and can easily handle the long tail is to use the nearest class mean classifier. This has the form

$$f(\mathbf{x}) = \operatorname*{argmin}_c \|\mathbf{x} - \boldsymbol{\mu}_c\|_2^2 \tag{10.98}$$

where $\boldsymbol{\mu}_c = \frac{1}{N_c} \sum_{n: y_n = c} \mathbf{x}_n$ is the mean of the features belonging to class $c$. This induces a softmax posterior, as we discussed in Section 9.2.5. We can get much better results if we first use a neural network (see Part III) to learn good features, by training a DNN classifier with cross-entropy loss on the original unbalanced data. We then replace $\mathbf{x}$ with $\boldsymbol{\phi}(\mathbf{x})$ in Equation (10.98). This simple approach can give very good performance on long-tailed distributions [Kan+20].
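A minimal sketch of the nearest class mean rule in Equation (10.98) is given below. It uses raw 2d inputs rather than learned features $\boldsymbol{\phi}(\mathbf{x})$, and the synthetic long-tailed dataset (class sizes 200, 20 and 3) is an assumption made purely for illustration.

```python
import numpy as np

def fit_class_means(X, y, num_classes):
    """Compute mu_c, the mean feature vector of each class."""
    return np.stack([X[y == c].mean(axis=0) for c in range(num_classes)])

def predict_ncm(X, means):
    """Assign each row of X to the class with the closest mean."""
    d2 = ((X[:, None, :] - means[None, :, :]) ** 2).sum(axis=-1)   # shape (N, C)
    return d2.argmin(axis=1)

rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [5.0, 5.0], [-5.0, 5.0]])
sizes = [200, 20, 3]                                  # long-tailed class sizes (assumption)
X = np.concatenate([c + 0.5 * rng.normal(size=(n, 2)) for c, n in zip(centers, sizes)])
y = np.concatenate([np.full(n, k) for k, n in enumerate(sizes)])

means = fit_class_means(X, y, num_classes=3)
train_acc = (predict_ncm(X, means) == y).mean()
center_labels = predict_ncm(centers, means)
```

Because each class contributes only its mean, a class with three examples is treated on the same footing as one with two hundred.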
Figure 10.10: (a) Logistic regression on some data with outliers (denoted by x). Training points have been (vertically) jittered to avoid overlapping too much. Vertical line is the decision boundary, and its posterior credible interval. (b) Same as (a) but using robust model, with a mixture likelihood. Adapted from Figure 4.13 of [Mar18]. Generated by logreg_iris_bayes_robust_1d_pymc3.ipynb.


10.4 Robust logistic regression * 


Sometimes we have outliers in our data, which are often due to labeling errors, also called label 
noise. To prevent the model from being adversely affected by such contamination, we will use 
robust logistic regression. In this section, we discuss some approaches to this problem. (Note 
that the methods can also be applied to DNNs. For a more thorough survey of label noise, and how 
it impacts deep learning, see [Han+20].) 


10.4.1 Mixture model for the likelihood 


One of the simplest ways to define a robust logistic regression model is to modify the likelihood so that it predicts that each output label $y$ is generated uniformly at random with probability $\pi$, and otherwise is generated using the usual conditional model. In the binary case, this becomes

$$p(y|\mathbf{x}) = \pi \mathrm{Ber}(y|0.5) + (1 - \pi)\mathrm{Ber}(y|\sigma(\mathbf{w}^\top \mathbf{x})) \tag{10.99}$$


This approach, of using a mixture model for the observation model to make it robust, can be applied 
to many different models (e.g., DNNs). 
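The following sketch evaluates the negative log likelihood implied by Equation (10.99) on a single confidently mislabeled point. The data point, weights, and the value $\pi = 0.1$ are illustrative assumptions, and `robust_nll` is a hypothetical helper name.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def robust_nll(w, X, y, pi):
    """NLL under p(y|x) = pi * Ber(y|0.5) + (1 - pi) * Ber(y|sigmoid(w^T x))."""
    mu = sigmoid(X @ w)
    ber = np.where(y == 1, mu, 1.0 - mu)     # Ber(y | mu) for binary y
    lik = pi * 0.5 + (1.0 - pi) * ber
    return -np.sum(np.log(lik))

# One point with label 1 that the model places far on the "0" side (assumption).
X = np.array([[1.0, -10.0]])                 # bias feature plus one input
y = np.array([1])
w = np.array([0.0, 1.0])

nll_standard = robust_nll(w, X, y, pi=0.0)   # ordinary logistic regression
nll_robust = robust_nll(w, X, y, pi=0.1)     # mixture likelihood
```

With $\pi > 0$ the loss on any single example is capped at $-\log(\pi/2)$, so a mislabeled outlier cannot dominate the objective.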

We can fit this model using standard methods, such as SGD or Bayesian inference methods such 
as MCMC. For example, let us create a “contaminated” version of the 1d, two-class Iris dataset that 
we discussed in Section 4.6.7.2. We will add 6 examples of class 1 (Versicolor) with abnormally 
low sepal length. In Figure 10.10a, we show the results of fitting a standard (Bayesian) logistic 
regression model to this dataset. In Figure 10.10b, we show the results of fitting the above robust 
model. In the latter case, we see that the decision boundary is similar to the one we inferred from 
non-contaminated data, as shown in Figure 4.20b. We also see that the posterior uncertainty about 
the decision boundary’s location is smaller than when using a non-robust model. 
10.4.2 Bi-tempered loss 


In this section, we present an approach to robust logistic regression proposed in [Ami+19]. 

The first observation is that examples that are far from the decision boundary, but mislabeled, will have an undue adverse effect on the model if the loss function is convex [LS10]. This can be overcome by replacing the usual cross entropy loss with a “tempered” version, that uses a temperature parameter $0 < t_1 < 1$ to ensure the loss from outliers is bounded. In particular, consider the standard cross entropy loss function:

$$\mathcal{L}(\mathbf{y}, \hat{\mathbf{y}}) = \mathbb{H}_{ce}(\mathbf{y}, \hat{\mathbf{y}}) = -\sum_c y_c \log \hat{y}_c \tag{10.100}$$


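where $\mathbf{y}$ is the true label distribution (often one-hot) and $\hat{\mathbf{y}}$ is the predicted distribution. We define the tempered cross entropy loss as follows:

$$\mathcal{L}(\mathbf{y}, \hat{\mathbf{y}}) = \sum_c \left[ y_c (\log_{t_1} y_c - \log_{t_1} \hat{y}_c) - \frac{1}{2 - t_1} (y_c^{2 - t_1} - \hat{y}_c^{2 - t_1}) \right] \tag{10.101}$$

which simplifies to the following when the true distribution $\mathbf{y}$ is one-hot, with all its mass on class $c$:

$$\mathcal{L}(c, \hat{\mathbf{y}}) = -\log_{t_1} \hat{y}_c - \frac{1}{2 - t_1} \left( 1 - \sum_{c'=1}^C \hat{y}_{c'}^{2 - t_1} \right) \tag{10.102}$$

Here $\log_t$ is a tempered version of the log function:

$$\log_t(x) \triangleq \frac{1}{1 - t} (x^{1 - t} - 1) \tag{10.103}$$
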
This is monotonically increasing and concave, and reduces to the standard (natural) logarithm when $t = 1$. (Similarly, tempered cross entropy reduces to standard cross entropy when $t = 1$.) However, the tempered log function is bounded from below by $-1/(1 - t)$ for $0 < t < 1$, and hence the cross entropy loss is bounded from above (see Figure 10.11). 
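The boundedness claim can be checked with a short sketch of Equations (10.102) and (10.103). The predicted distribution below, which puts probability $10^{-6}$ on the true class, is an assumed example; the function names are hypothetical.

```python
import numpy as np

def log_t(x, t):
    """Tempered logarithm of Equation (10.103); equals log(x) when t = 1."""
    if t == 1.0:
        return np.log(x)
    return (x ** (1.0 - t) - 1.0) / (1.0 - t)

def tempered_ce_onehot(y_hat, c, t):
    """Equation (10.102): tempered cross entropy for a one-hot target on class c."""
    y_hat = np.asarray(y_hat, dtype=float)
    return -log_t(y_hat[c], t) - (1.0 - np.sum(y_hat ** (2.0 - t))) / (2.0 - t)

y_hat = np.array([1e-6, 1.0 - 1e-6])     # true class 0 gets tiny probability (assumption)
loss_standard = tempered_ce_onehot(y_hat, c=0, t=1.0)
loss_tempered = tempered_ce_onehot(y_hat, c=0, t=0.8)
upper_bound = 1.0 / (1.0 - 0.8)          # the loss cannot exceed 1/(1 - t) when t < 1
```

For this example the standard loss is about 13.8, while the tempered loss with $t = 0.8$ stays below the bound of 5.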

The second observation is that examples that are near the decision boundary, but mislabeled, need to use a transfer function (that maps from activations $\mathbb{R}^C$ to probabilities $[0, 1]^C$) that has heavier tails than the softmax, which is based on the exponential, so it can “look past” the neighborhood of the immediate examples. In particular, the standard softmax is defined by

$$\hat{y}_c = \frac{\exp(a_c)}{\sum_{c'=1}^C \exp(a_{c'})} = \exp\left[ a_c - \log \sum_{c'=1}^C \exp(a_{c'}) \right] \tag{10.104}$$


where $\mathbf{a}$ is the logits vector. We can make a heavy tailed version by using the tempered softmax, which uses a temperature parameter $t_2 > 1 > t_1$ as follows:

$$\hat{y}_c = \exp_{t_2}(a_c - \lambda_{t_2}(\mathbf{a})) \tag{10.105}$$

where

$$\exp_t(x) \triangleq [1 + (1 - t)x]_+^{1/(1 - t)} \tag{10.106}$$

is a tempered version of the exponential function. (This reduces to the standard exponential function as $t \to 1$.) In Figure 10.11 (right), we show that the tempered softmax (in the two-class case) has heavier tails, as desired.

Figure 10.11: (a) Illustration of logistic and tempered logistic loss with $t_1 = 0.8$. (b) Illustration of sigmoid and tempered sigmoid transfer function with $t_2 = 2.0$. From https://ai.googleblog.com/2019/08/bi-tempered-logistic-loss-for-training.html. Used with kind permission of Ehsan Amid.


All that remains is a way to compute $\lambda_{t_2}(\mathbf{a})$. This must satisfy the following fixed point equation:

$$\sum_{c=1}^C \exp_{t_2}(a_c - \lambda(\mathbf{a})) = 1 \tag{10.107}$$

We can solve for $\lambda$ using binary search, or by using the iterative procedure in Algorithm 10.2.


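Algorithm 10.2: Iterative algorithm for computing $\lambda(\mathbf{a})$ in Equation (10.107). From [AWS19].

Input: logits $\mathbf{a}$, temperature $t > 1$
$\mu := \max(\mathbf{a})$
$\tilde{\mathbf{a}} := \mathbf{a} - \mu \mathbf{1}$
while $\tilde{\mathbf{a}}$ not converged do
    $Z(\tilde{\mathbf{a}}) := \sum_{c=1}^C \exp_t(\tilde{a}_c)$
    $\tilde{\mathbf{a}} := Z(\tilde{\mathbf{a}})^{1-t} (\mathbf{a} - \mu \mathbf{1})$
Return $-\log_t \frac{1}{Z(\tilde{\mathbf{a}})} + \mu$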
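The iterative procedure of Algorithm 10.2 can be sketched in NumPy as follows. This is an illustrative reading of the algorithm rather than the reference implementation: the stopping rule (`tol`, `max_iters`) and the example logits are assumptions, and the iteration is only intended for $t > 1$.

```python
import numpy as np

def exp_t(x, t):
    """Tempered exponential of Equation (10.106)."""
    if t == 1.0:
        return np.exp(x)
    return np.maximum(1.0 + (1.0 - t) * x, 0.0) ** (1.0 / (1.0 - t))

def log_t(x, t):
    """Tempered logarithm of Equation (10.103)."""
    if t == 1.0:
        return np.log(x)
    return (x ** (1.0 - t) - 1.0) / (1.0 - t)

def tempered_normalizer(a, t, max_iters=100, tol=1e-12):
    """Fixed-point iteration for lambda_t(a), assuming t > 1."""
    a = np.asarray(a, dtype=float)
    mu = a.max()
    a0 = a - mu                          # shifted logits, all <= 0
    a_tilde = a0.copy()
    for _ in range(max_iters):
        Z = exp_t(a_tilde, t).sum()
        a_next = Z ** (1.0 - t) * a0     # rescale the shifted logits
        converged = np.max(np.abs(a_next - a_tilde)) < tol
        a_tilde = a_next
        if converged:
            break
    Z = exp_t(a_tilde, t).sum()
    return -log_t(1.0 / Z, t) + mu

def tempered_softmax(a, t):
    """Equation (10.105): y_hat_c = exp_t(a_c - lambda_t(a))."""
    a = np.asarray(a, dtype=float)
    return exp_t(a - tempered_normalizer(a, t), t)

logits = np.array([1.0, 2.0, 0.5])       # illustrative logits (assumption)
p_tempered = tempered_softmax(logits, t=2.0)
p_softmax = np.exp(logits - logits.max())
p_softmax /= p_softmax.sum()
```

On these logits the tempered output sums to one and assigns more mass to the least likely class than the standard softmax does, which is the heavy-tail behavior described above.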


Combining the tempered softmax with the tempered cross entropy results in a method called bi-tempered logistic regression. In Figure 10.12, we show an example of this in 2d. The top row is standard logistic regression, the bottom row is bi-tempered. The first column is clean data. The second column has label noise near the boundary. The robust version uses $t_1 = 1$ (standard cross entropy) but $t_2 = 4$ (tempered softmax with heavy tails). The third column has label noise far from the boundary. The robust version uses $t_1 = 0.2$ (tempered cross entropy with bounded loss) but $t_2 = 1$ (standard softmax). The fourth column has both kinds of noise; in this case, the robust version uses $t_1 = 0.2$ and $t_2 = 4$.

Figure 10.12: Illustration of standard and bi-tempered logistic regression on data with label noise. From https://ai.googleblog.com/2019/08/bi-tempered-logistic-loss-for-training.html. Used with kind permission of Ehsan Amid.


10.5 Bayesian logistic regression * 


So far we have focused on point estimates of the parameters, either the MLE or the MAP estimate. 
However, in some cases we want to compute the posterior, p(w|D), in order to capture our uncertainty. 
This can be particularly useful in settings where we have little data, and where choosing the wrong 
decision may be costly. 

Unlike with linear regression, it is not possible to compute the posterior exactly for a logistic regression model. A wide range of approximate algorithms can be used. In this section, we use one of the simplest, known as the Laplace approximation (Section 4.6.8.2). See the sequel to this book, [Mur23], for more advanced approximations. 


10.5.1 Laplace approximation 


As we discuss in Section 4.6.8.2, the Laplace approximation approximates the posterior using a Gaussian. The mean of the Gaussian is equal to the MAP estimate $\hat{\mathbf{w}}$, and the covariance is equal to the inverse of the Hessian $\mathbf{H}$ computed at the MAP estimate, i.e., $p(\mathbf{w}|\mathcal{D}) \approx \mathcal{N}(\mathbf{w}|\hat{\mathbf{w}}, \mathbf{H}^{-1})$. We can find the mode using a standard optimization method (see Section 10.2.7), and then we can use the results from Section 10.2.3.4 to compute the Hessian at the mode.

Figure 10.13: (a) Illustration of the data. (b) Log-likelihood for a logistic regression model. The line is drawn from the origin in the direction of the MLE (which is at infinity). The numbers correspond to 4 points in parameter space, corresponding to the lines in (a). (c) Unnormalized log posterior (assuming vague spherical prior). (d) Laplace approximation to posterior. Adapted from a figure by Mark Girolami. Generated by logreg_laplace_demo.ipynb.
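The two steps (find the mode, then evaluate the Hessian there) can be sketched as below. This is not the code in logreg_laplace_demo.ipynb: the damped Newton optimizer, the synthetic two-cluster data, and the function names are assumptions chosen to keep the example self-contained. The prior is $\mathcal{N}(\mathbf{w}|\mathbf{0}, \sigma^2 \mathbf{I})$ with `prior_var` playing the role of $\sigma^2$.

```python
import numpy as np

def sigmoid(a):
    # Numerically stable logistic function
    return 0.5 * (1.0 + np.tanh(0.5 * a))

def neg_log_posterior(w, X, y, prior_var):
    a = X @ w
    nll = np.sum(np.logaddexp(0.0, a) - y * a)      # Bernoulli NLL for y in {0, 1}
    return nll + 0.5 * (w @ w) / prior_var          # plus Gaussian prior term

def grad_and_hessian(w, X, y, prior_var):
    mu = sigmoid(X @ w)
    grad = X.T @ (mu - y) + w / prior_var
    H = X.T @ (X * (mu * (1.0 - mu))[:, None]) + np.eye(len(w)) / prior_var
    return grad, H

def laplace_logreg(X, y, prior_var, num_iters=50):
    """Return (w_map, Sigma), where Sigma is the inverse Hessian at the mode."""
    w = np.zeros(X.shape[1])
    for _ in range(num_iters):
        grad, H = grad_and_hessian(w, X, y, prior_var)
        step = np.linalg.solve(H, grad)
        f_w, t = neg_log_posterior(w, X, y, prior_var), 1.0
        # Halve the Newton step until the objective does not increase
        while neg_log_posterior(w - t * step, X, y, prior_var) > f_w and t > 1e-8:
            t *= 0.5
        w = w - t * step
    _, H = grad_and_hessian(w, X, y, prior_var)
    return w, np.linalg.inv(H)

rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(loc=[1.0, 1.0], size=(20, 2)),
                    rng.normal(loc=[-1.0, -1.0], size=(20, 2))])
y = np.concatenate([np.ones(20), np.zeros(20)])
w_map, Sigma = laplace_logreg(X, y, prior_var=100.0)
```

The returned pair defines the Gaussian approximation $\mathcal{N}(\mathbf{w}|\hat{\mathbf{w}}, \mathbf{H}^{-1})$, from which weight samples can be drawn for the posterior predictive.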

As an example, consider the data illustrated in Figure 10.13(a). There are many parameter settings that correspond to lines that perfectly separate the training data; we show 4 example lines. The likelihood surface is shown in Figure 10.13(b). The diagonal line connects the origin to the point in the grid with maximum likelihood, $\hat{\mathbf{w}}_{mle} = (8.0, 3.4)$. (The unconstrained MLE has $\|\mathbf{w}\| = \infty$, as we discussed in Section 10.2.7; this point can be obtained by following the diagonal line infinitely far to the right.) 

For each decision boundary in Figure 10.13(a), we plot the corresponding parameter vector in Figure 10.13(b). These parameter values are $\mathbf{w}_1 = (3, 1)$, $\mathbf{w}_2 = (4, 2)$, $\mathbf{w}_3 = (5, 3)$, and $\mathbf{w}_4 = (7, 3)$. These points all approximately satisfy $\mathbf{w}_i(1)/\mathbf{w}_i(2) \approx \hat{\mathbf{w}}_{mle}(1)/\hat{\mathbf{w}}_{mle}(2)$, and hence are close to the orientation of the maximum likelihood decision boundary. The points are ordered by increasing weight norm (3.16, 4.47, 5.83, and 7.62).

Figure 10.14: Posterior predictive distribution for a logistic regression model in 2d. (a): contours of $p(y = 1|\mathbf{x}, \hat{\mathbf{w}}_{map})$. (b): samples from the posterior predictive distribution. (c): Averaging over these samples. (d): moderated output (probit approximation). Adapted from a figure by Mark Girolami. Generated by logreg_laplace_demo.ipynb.

To ensure a unique solution, we use a (spherical) Gaussian prior centered at the origin, $\mathcal{N}(\mathbf{w}|\mathbf{0}, \sigma^2 \mathbf{I})$. The value of $\sigma^2$ controls the strength of the prior. If we set $\sigma^2 = 0$, we force the MAP estimate to be $\mathbf{w} = \mathbf{0}$; this will result in maximally uncertain predictions, since all points $\mathbf{x}$ will produce a predictive distribution of the form $p(y = 1|\mathbf{x}) = 0.5$. If we set $\sigma^2 = \infty$, the prior becomes uninformative, and the MAP estimate becomes the MLE, resulting in minimally uncertain predictions. (In particular, all positively labeled points will have $p(y = 1|\mathbf{x}) = 1.0$, and all negatively labeled points will have $p(y = 1|\mathbf{x}) = 0.0$, since the data is separable.) As a compromise (to make a nice illustration), we pick the value $\sigma^2 = 100$. 

Multiplying this prior by the likelihood results in the unnormalized posterior shown in Figure 10.13(c). The MAP estimate is shown by the red dot. The Laplace approximation to this 
posterior is shown in Figure 10.13(d). We see that it gets the mode correct (by construction), but 
the shape of the posterior is somewhat distorted. (The southwest-northeast orientation captures 
uncertainty about the magnitude of w, and the southeast-northwest orientation captures uncertainty 
about the orientation of the decision boundary.) 

In Figure 10.14, we show contours of the posterior predictive distribution. Figure 10.14(a) shows the 
plugin approximation using the MAP estimate. We see that there is no uncertainty about the decision 
boundary, even though we are generating probabilistic predictions over the labels. Figure 10.14(b) 
shows what happens when we plug in samples from the Gaussian posterior. Now we see that there is 
considerable uncertainty about the orientation of the “best” decision boundary. Figure 10.14(c) shows 
the average of these samples. By averaging over multiple predictions, we see that the uncertainty in 
the decision boundary “splays out” as we move further from the training data. Figure 10.14(d) shows 
that the probit approximation gives very similar results to the Monte Carlo approximation. 


10.5.2 Approximating the posterior predictive 


The posterior $p(\mathbf{w}|\mathcal{D})$ tells us everything we know about the parameters of the model given the data. However, in machine learning applications, the main task of interest is usually to predict an output $y$ given an input $\mathbf{x}$, rather than to try to understand the parameters of our model. Thus we need to compute the posterior predictive distribution

$$p(y|\mathbf{x}, \mathcal{D}) = \int p(y|\mathbf{x}, \mathbf{w}) p(\mathbf{w}|\mathcal{D}) d\mathbf{w} \tag{10.108}$$

As we discussed in Section 4.6.7.1, a simple approach to this is to first compute a point estimate $\hat{\mathbf{w}}$ of the parameters, such as the MLE or MAP estimate, and then to ignore all posterior uncertainty, by assuming $p(\mathbf{w}|\mathcal{D}) = \delta(\mathbf{w} - \hat{\mathbf{w}})$. In this case, the above integral reduces to the following plugin approximation:

$$p(y|\mathbf{x}, \mathcal{D}) \approx \int p(y|\mathbf{x}, \mathbf{w}) \delta(\mathbf{w} - \hat{\mathbf{w}}) d\mathbf{w} = p(y|\mathbf{x}, \hat{\mathbf{w}}) \tag{10.109}$$

However, if we want to compute uncertainty in our predictions, we should use a non-degenerate posterior. It is common to use a Gaussian posterior, as we will see. But we still need to approximate the integral in Equation (10.108). We discuss some approaches to this below.


10.5.2.1 Monte Carlo approximation 


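The simplest approach is to use a Monte Carlo approximation to the integral. This means we draw $S$ samples from the posterior, $\mathbf{w}_s \sim p(\mathbf{w}|\mathcal{D})$, and then compute

$$p(y = 1|\mathbf{x}, \mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^S \sigma(\mathbf{w}_s^\top \mathbf{x}) \tag{10.110}$$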


10.5.2.2 Probit approximation 


Although the Monte Carlo approximation is simple, it can be slow, since we need to draw $S$ samples at test time for each input $\mathbf{x}$. Fortunately, if $p(\mathbf{w}|\mathcal{D}) = \mathcal{N}(\mathbf{w}|\boldsymbol{\mu}, \boldsymbol{\Sigma})$, there is a simple yet accurate deterministic approximation, first suggested in [SL90]. To explain this approximation, we follow the presentation of [Bis06, p219]. The key observation is that the sigmoid function $\sigma(a)$ is similar in shape to the Gaussian cdf (see Section 2.6.1) $\Phi(a)$. In particular we have $\sigma(a) \approx \Phi(\lambda a)$, where $\lambda^2 = \pi/8$ ensures the two functions have the same slope at the origin. This is useful since we can integrate a Gaussian cdf wrt a Gaussian pdf exactly:

$$\int \Phi(\lambda a) \mathcal{N}(a|m, v) da = \Phi\left( \frac{m}{(\lambda^{-2} + v)^{\frac{1}{2}}} \right) = \Phi\left( \frac{\lambda m}{(1 + \lambda^2 v)^{\frac{1}{2}}} \right) \approx \sigma(\kappa(v) m) \tag{10.111}$$

where we have defined

$$\kappa(v) \triangleq (1 + \pi v/8)^{-\frac{1}{2}} \tag{10.112}$$

Thus if we define $a = \mathbf{x}^\top \mathbf{w}$, we have

$$p(y = 1|\mathbf{x}, \mathcal{D}) \approx \sigma(\kappa(v) m) \tag{10.113}$$
$$m = \mathbb{E}[a] = \mathbf{x}^\top \boldsymbol{\mu} \tag{10.114}$$
$$v = \mathbb{V}[a] = \mathbb{V}[\mathbf{x}^\top \mathbf{w}] = \mathbf{x}^\top \boldsymbol{\Sigma} \mathbf{x} \tag{10.115}$$

where we used Equation (2.165) in the last line. Since $\Phi$ is the inverse of the probit function, we will call this the probit approximation. 
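The Monte Carlo estimate of Equation (10.110) and the probit approximation of Equation (10.113) can be compared directly. In the sketch below, the Gaussian posterior parameters `mu` and `Sigma` and the test input `x` are made-up values, not quantities fitted to any data set in this chapter.

```python
import numpy as np

def sigmoid(a):
    # Numerically stable logistic function
    return 0.5 * (1.0 + np.tanh(0.5 * a))

def predict_mc(x, mu, Sigma, num_samples=100_000, seed=0):
    """Equation (10.110): average sigma(w_s^T x) over posterior samples."""
    rng = np.random.default_rng(seed)
    W = rng.multivariate_normal(mu, Sigma, size=num_samples)
    return sigmoid(W @ x).mean()

def predict_probit(x, mu, Sigma):
    """Equations (10.112)-(10.115): sigma(kappa(v) m)."""
    m = x @ mu
    v = x @ Sigma @ x
    kappa = 1.0 / np.sqrt(1.0 + np.pi * v / 8.0)
    return sigmoid(kappa * m)

mu = np.array([1.0, -0.5])                   # illustrative posterior mean (assumption)
Sigma = np.array([[0.5, 0.1], [0.1, 0.3]])   # illustrative posterior covariance (assumption)
x = np.array([2.0, 1.0])

p_plugin = sigmoid(x @ mu)                   # ignores posterior uncertainty
p_mc = predict_mc(x, mu, Sigma)
p_probit = predict_probit(x, mu, Sigma)
```

For this input the two approximations agree closely, and both are less confident (closer to 0.5) than the plug-in prediction, while the deterministic version needs no sampling at test time.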

Using Equation (10.113) results in predictions that are less extreme (in terms of their confidence) than the plug-in estimate. To see this, note that $0 < \kappa(v) < 1$ and hence $|\kappa(v) m| < |m|$, so $\sigma(\kappa(v) m)$ is closer to 0.5 than $\sigma(m)$ is. However, the decision boundary itself will not be affected. To see this, note that the decision boundary is the set of points $\mathbf{x}$ for which $p(y = 1|\mathbf{x}, \mathcal{D}) = 0.5$. This implies $\kappa(v) m = 0$, which implies $m = \boldsymbol{\mu}^\top \mathbf{x} = 0$; when $\boldsymbol{\mu}$ is the MAP estimate, as in the Laplace approximation, this is the same as the decision boundary from the plugin estimate. Thus “being Bayesian” doesn’t change the misclassification rate (in this case), but it does change the confidence estimates of the model, which can be important, as we illustrate in Section 10.5.1. 


In the multiclass case we can use the generalized probit approximation [Gib97]:

$$p(y = c|\mathbf{x}, \mathcal{D}) \approx \frac{\exp(\kappa(v_c) m_c)}{\sum_{c'} \exp(\kappa(v_{c'}) m_{c'})} \tag{10.116}$$
$$m_c = \overline{\mathbf{m}}_c^\top \mathbf{x} \tag{10.117}$$
$$v_c = \mathbf{x}^\top \mathbf{V}_{c,c} \mathbf{x} \tag{10.118}$$

where $\kappa$ is defined in Equation (10.112). Unlike the binary case, taking into account posterior covariance gives different predictions than the plug-in approach (see Exercise 3.10.3 of [RW06]).

For further approximations of Gaussian integrals combined with sigmoid and softmax functions, see [Dau17]. 


10.6 Exercises 


Exercise 10.1 [Gradient and Hessian of log-likelihood for multinomial logistic regression]

a. Let $\mu_{ik} = \mathrm{softmax}(\boldsymbol{\eta}_i)_k$, where $\boldsymbol{\eta}_i = \mathbf{w}^\top \mathbf{x}_i$. Show that the Jacobian of the softmax is

$$\frac{\partial \mu_{ik}}{\partial \eta_{ij}} = \mu_{ik} (\delta_{kj} - \mu_{ij}) \tag{10.119}$$

where $\delta_{kj} = \mathbb{I}(k = j)$.

b. Hence show that the gradient of the NLL is given by

$$\nabla_{\mathbf{w}_c} \ell = \sum_i (y_{ic} - \mu_{ic}) \mathbf{x}_i \tag{10.120}$$

Hint: use the chain rule and the fact that $\sum_c y_{ic} = 1$.

c. Show that the block submatrix of the Hessian for classes $c$ and $c'$ is given by

$$\mathbf{H}_{c,c'} = -\sum_i \mu_{ic} (\delta_{c,c'} - \mu_{i,c'}) \mathbf{x}_i \mathbf{x}_i^\top \tag{10.121}$$

Hence show that the Hessian of the NLL is positive definite. 


Exercise 10.2 [Regularizing separate terms in 2d logistic regression *]
(Source: Jaakkola.)

a. Consider the data in Figure 10.15a, where we fit the model $p(y = 1|\mathbf{x}, \mathbf{w}) = \sigma(w_0 + w_1 x_1 + w_2 x_2)$. Suppose we fit the model by maximum likelihood, i.e., we minimize

$$J(\mathbf{w}) = -\ell(\mathbf{w}, \mathcal{D}_{\text{train}}) \tag{10.122}$$

where $\ell(\mathbf{w}, \mathcal{D}_{\text{train}})$ is the log likelihood on the training set. Sketch a possible decision boundary corresponding to $\hat{\mathbf{w}}$. (Copy the figure first (a rough sketch is enough), and then superimpose your answer on your copy, since you will need multiple versions of this figure.) Is your answer (decision boundary) unique? How many classification errors does your method make on the training set?

b. Now suppose we regularize only the $w_0$ parameter, i.e., we minimize

$$J_0(\mathbf{w}) = -\ell(\mathbf{w}, \mathcal{D}_{\text{train}}) + \lambda w_0^2 \tag{10.123}$$

Suppose $\lambda$ is a very large number, so we regularize $w_0$ all the way to 0, but all other parameters are unregularized. Sketch a possible decision boundary. How many classification errors does your method make on the training set? Hint: consider the behavior of simple linear regression, $w_0 + w_1 x_1 + w_2 x_2$ when $x_1 = x_2 = 0$.

c. Now suppose we heavily regularize only the $w_1$ parameter, i.e., we minimize

$$J_1(\mathbf{w}) = -\ell(\mathbf{w}, \mathcal{D}_{\text{train}}) + \lambda w_1^2 \tag{10.124}$$

Sketch a possible decision boundary. How many classification errors does your method make on the training set?

d. Now suppose we heavily regularize only the $w_2$ parameter. Sketch a possible decision boundary. How many classification errors does your method make on the training set? 


Exercise 10.3 [Logistic regression vs LDA/QDA *]

(Source: Jaakkola.) Suppose we train the following binary classifiers via maximum likelihood.

a. GaussI: A generative classifier, where the class-conditional densities are Gaussian, with both covariance matrices set to $\mathbf{I}$ (identity matrix), i.e., $p(\mathbf{x}|y = c) = \mathcal{N}(\mathbf{x}|\boldsymbol{\mu}_c, \mathbf{I})$. We assume $p(y)$ is uniform.

b. GaussX: as for GaussI, but the covariance matrices are unconstrained, i.e., $p(\mathbf{x}|y = c) = \mathcal{N}(\mathbf{x}|\boldsymbol{\mu}_c, \boldsymbol{\Sigma}_c)$.

c. LinLog: A logistic regression model with linear features. 
Figure 10.15: (a) Data for logistic regression question. (b) Plot of $\hat{w}_k$ vs amount of correlation $c_k$ for three different estimators. 


d. QuadLog: A logistic regression model, using linear and quadratic features (i.e., polynomial basis function 
expansion of degree 2). 


After training we compute the performance of each model $M$ on the training set as follows:

$$L(M) = \frac{1}{N} \sum_{i=1}^N \log p(y_i|\mathbf{x}_i, \hat{\boldsymbol{\theta}}, M) \tag{10.125}$$

(Note that this is the conditional log-likelihood $p(y|\mathbf{x}, \hat{\boldsymbol{\theta}})$ and not the joint log-likelihood $p(y, \mathbf{x}|\hat{\boldsymbol{\theta}})$.) We now want to compare the performance of each model. We will write $L(M) \le L(M')$ if model $M$ must have lower (or equal) log likelihood (on the training set) than $M'$, for any training set (in other words, $M$ is worse than $M'$, at least as far as training set logprob is concerned). For each of the following model pairs, state whether $L(M) \le L(M')$, $L(M) \ge L(M')$, or whether no such statement can be made (i.e., $M$ might sometimes be better than $M'$ and sometimes worse); also, for each question, briefly (1-2 sentences) explain why.

a. GaussI, LinLog.
b. GaussX, QuadLog.
c. LinLog, QuadLog.
d. GaussI, QuadLog.

Now suppose we measure performance in terms of the average misclassification rate on the training set:

$$R(M) = \frac{1}{N} \sum_{i=1}^N \mathbb{I}(y_i \ne \hat{y}(\mathbf{x}_i)) \tag{10.126}$$

Is it true in general that $L(M) > L(M')$ implies that $R(M) < R(M')$? Explain why or why not. 
11 Linear Regression 


11.1 Introduction 


In this chapter, we discuss linear regression, which is a very widely used method for predicting a real-valued output (also called the dependent variable or target) $y \in \mathbb{R}$, given a vector of real-valued inputs (also called independent variables, explanatory variables, or covariates) $\mathbf{x} \in \mathbb{R}^D$. The key property of the model is that the expected value of the output is assumed to be a linear function of the input, $\mathbb{E}[y|\mathbf{x}] = \mathbf{w}^\top \mathbf{x}$, which makes the model easy to interpret, and easy to fit to data. We discuss nonlinear extensions later in this book. 


11.2 Least squares linear regression 


In this section, we discuss the most common form of linear regression model. 


11.2.1 Terminology 


The term “linear regression” usually refers to a model of the following form:

$$p(y|\mathbf{x}, \boldsymbol{\theta}) = \mathcal{N}(y|w_0 + \mathbf{w}^\top \mathbf{x}, \sigma^2) \tag{11.1}$$

where $\boldsymbol{\theta} = (w_0, \mathbf{w}, \sigma^2)$ are all the parameters of the model. (In statistics, the parameters $w_0$ and $\mathbf{w}$ are usually denoted by $\beta_0$ and $\boldsymbol{\beta}$.)

The vector of parameters $w_{1:D}$ are known as the weights or regression coefficients. Each coefficient $w_d$ specifies the change in the output we expect if we change the corresponding input feature $x_d$ by one unit. For example, suppose $x_1$ is the age of a person, $x_2$ is their education level (represented as a continuous number), and $y$ is their income. Thus $w_1$ corresponds to the increase in income we expect as someone becomes one year older (and hence gets more experience), and $w_2$ corresponds to the increase in income we expect as someone’s education level increases by one level. The term $w_0$ is the offset or bias term, and specifies the output value if all the inputs are 0. This captures the unconditional mean of the response, $w_0 = \mathbb{E}[y]$, and acts as a baseline. We will usually assume that $\mathbf{x}$ is written as $[1, x_1, \ldots, x_D]$, so we can absorb the offset term $w_0$ into the weight vector $\mathbf{w}$. 

If the input is one-dimensional (so $D = 1$), the model has the form $f(x; \mathbf{w}) = ax + b$, where $b = w_0$ is the intercept, and $a = w_1$ is the slope. This is called simple linear regression. If the input is multi-dimensional, $\mathbf{x} \in \mathbb{R}^D$ where $D > 1$, the method is called multiple linear regression. If the output is also multi-dimensional, $\mathbf{y} \in \mathbb{R}^J$, where $J > 1$, it is called multivariate linear regression,

$$p(\mathbf{y}|\mathbf{x}, \mathbf{W}) = \prod_{j=1}^J \mathcal{N}(y_j|\mathbf{w}_j^\top \mathbf{x}, \sigma_j^2) \tag{11.2}$$

See Exercise 11.1 for a simple numerical example.

Figure 11.1: Polynomial of degrees 1 and 2 fit to 21 datapoints. Generated by linreg_poly_vs_degree.ipynb.

In general, a straight line will not provide a good fit to most data sets. However, we can always apply a nonlinear transformation to the input features, by replacing $\mathbf{x}$ with $\boldsymbol{\phi}(\mathbf{x})$ to get

$$p(y|\mathbf{x}, \boldsymbol{\theta}) = \mathcal{N}(y|\mathbf{w}^\top \boldsymbol{\phi}(\mathbf{x}), \sigma^2) \tag{11.3}$$

As long as the parameters of the feature extractor $\boldsymbol{\phi}$ are fixed, the model remains linear in the parameters, even if it is not linear in the inputs. (We discuss ways to learn the feature extractor, and the final linear mapping, in Part III.)

As a simple example of a nonlinear transformation, consider the case of polynomial regression, which we introduced in Section 1.2.2.2. If the input is 1d, and we use a polynomial expansion of degree $d$, we get $\boldsymbol{\phi}(x) = [1, x, x^2, \ldots, x^d]$. See Figure 11.1 for an example. (See also Section 11.5 where we discuss splines.) 
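A short sketch of the polynomial feature map and the resulting linear fit is given below. The synthetic quadratic data (21 points on $[0, 20]$) is an assumption chosen to resemble the setting of Figure 11.1; it is not the data used by linreg_poly_vs_degree.ipynb.

```python
import numpy as np

def poly_features(x, degree):
    """phi(x) = [1, x, x^2, ..., x^degree] for each scalar in x."""
    return np.vander(np.asarray(x, dtype=float), N=degree + 1, increasing=True)

rng = np.random.default_rng(0)
x = np.linspace(0.0, 20.0, 21)
# Synthetic targets from a known quadratic plus noise (assumption)
y = 1.0 + 0.5 * x - 0.05 * x**2 + rng.normal(scale=0.1, size=x.shape)

Phi = poly_features(x, degree=2)                 # design matrix, shape (21, 3)
w_hat, *_ = np.linalg.lstsq(Phi, y, rcond=None)  # model is linear in w
```

Although the fitted curve is quadratic in $x$, the estimation problem is an ordinary linear least squares problem in $\mathbf{w}$.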


11.2.2 Least squares estimation 


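To fit a linear regression model to data, we will minimize the negative log likelihood on the training set. The objective function is given by

$$\mathrm{NLL}(\mathbf{w}, \sigma^2) = -\sum_{n=1}^N \log \left[ \left( \frac{1}{2\pi\sigma^2} \right)^{\frac{1}{2}} \exp\left( -\frac{1}{2\sigma^2} (y_n - \mathbf{w}^\top \mathbf{x}_n)^2 \right) \right] \tag{11.4}$$
$$= \frac{1}{2\sigma^2} \sum_{n=1}^N (y_n - \hat{y}_n)^2 + \frac{N}{2} \log(2\pi\sigma^2) \tag{11.5}$$

where we have defined the predicted response $\hat{y}_n \triangleq \mathbf{w}^\top \mathbf{x}_n$. The MLE is the point where $\nabla_{\mathbf{w}, \sigma} \mathrm{NLL}(\mathbf{w}, \sigma^2) = \mathbf{0}$. We can first optimize wrt $\mathbf{w}$, and then solve for the optimal $\sigma$. 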

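In this section, we just focus on estimating the weights $\mathbf{w}$. In this case, the NLL is equal (up to irrelevant constants) to the residual sum of squares, which is given by

$$\mathrm{RSS}(\mathbf{w}) = \frac{1}{2} \sum_{n=1}^N (y_n - \mathbf{w}^\top \mathbf{x}_n)^2 = \frac{1}{2} \|\mathbf{X}\mathbf{w} - \mathbf{y}\|_2^2 = \frac{1}{2} (\mathbf{X}\mathbf{w} - \mathbf{y})^\top (\mathbf{X}\mathbf{w} - \mathbf{y}) \tag{11.6}$$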

We discuss how to optimize this below. 


11.2.2.1 Ordinary least squares 
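From Equation (7.264) we can show that the gradient is given by

$$\nabla_{\mathbf{w}} \mathrm{RSS}(\mathbf{w}) = \mathbf{X}^\top \mathbf{X} \mathbf{w} - \mathbf{X}^\top \mathbf{y} \tag{11.7}$$

Setting the gradient to zero and solving gives

$$\mathbf{X}^\top \mathbf{X} \mathbf{w} = \mathbf{X}^\top \mathbf{y} \tag{11.8}$$

These are known as the normal equations, since, at the optimal solution, $\mathbf{y} - \mathbf{X}\mathbf{w}$ is normal (orthogonal) to the range of $\mathbf{X}$, as we explain in Section 11.2.2.2. The corresponding solution $\hat{\mathbf{w}}$ is the ordinary least squares (OLS) solution, which is given by

$$\hat{\mathbf{w}} = (\mathbf{X}^\top \mathbf{X})^{-1} \mathbf{X}^\top \mathbf{y} \tag{11.9}$$

The quantity $\mathbf{X}^\dagger = (\mathbf{X}^\top \mathbf{X})^{-1} \mathbf{X}^\top$ is the (left) pseudo inverse of the (non-square) matrix $\mathbf{X}$ (see Section 7.5.3 for more details). 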
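As a quick numerical check of Equations (11.7) to (11.9), the sketch below solves the normal equations on synthetic data and compares the result with the pseudo inverse. The data-generating weights and noise level are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 50, 3
X = rng.normal(size=(N, D))
w_true = np.array([1.0, -2.0, 0.5])          # illustrative ground truth (assumption)
y = X @ w_true + 0.1 * rng.normal(size=N)

# Solve the normal equations X^T X w = X^T y as a linear system
w_normal = np.linalg.solve(X.T @ X, X.T @ y)

# Same solution via the pseudo inverse X^dagger y
w_pinv = np.linalg.pinv(X) @ y

# Gradient of the RSS at the solution, Equation (11.7); should be near zero
grad_at_solution = X.T @ X @ w_normal - X.T @ y
```

Both routes give the same weights here because $\mathbf{X}$ is full rank and well conditioned; the two can differ in accuracy when $\mathbf{X}^\top \mathbf{X}$ is ill conditioned.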

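We can check that the solution is unique by showing that the Hessian is positive definite. In this case, the Hessian is given by

$$\mathbf{H}(\mathbf{w}) = \frac{\partial^2}{\partial \mathbf{w}^2} \mathrm{RSS}(\mathbf{w}) = \mathbf{X}^\top \mathbf{X} \tag{11.10}$$

If $\mathbf{X}$ is full rank (so the columns of $\mathbf{X}$ are linearly independent), then $\mathbf{H}$ is positive definite, since for any $\mathbf{v} \ne \mathbf{0}$, we have

$$\mathbf{v}^\top (\mathbf{X}^\top \mathbf{X}) \mathbf{v} = (\mathbf{X}\mathbf{v})^\top (\mathbf{X}\mathbf{v}) = \|\mathbf{X}\mathbf{v}\|^2 > 0 \tag{11.11}$$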


Hence in the full rank case, the least squares objective has a unique global minimum. See Figure 11.2 
for an illustration. 


11.2.2.2 Geometric interpretation of least squares 


The normal equations have an elegant geometrical interpretation, deriving from Section 7.7, as we now explain. We will assume $N > D$, so there are more observations than unknowns. (This is known as an overdetermined system.) We seek a vector $\hat{\mathbf{y}} \in \mathbb{R}^N$ that lies in the linear subspace spanned by $\mathbf{X}$ and is as close as possible to $\mathbf{y}$, i.e., we want to find

$$\operatorname*{argmin}_{\hat{\mathbf{y}} \in \mathrm{span}(\{\mathbf{x}_{:,1}, \ldots, \mathbf{x}_{:,D}\})} \|\mathbf{y} - \hat{\mathbf{y}}\|_2 \tag{11.12}$$

where $\mathbf{x}_{:,d}$ is the $d$'th column of $\mathbf{X}$. Since $\hat{\mathbf{y}} \in \mathrm{span}(\mathbf{X})$, there exists some weight vector $\mathbf{w}$ such that

$$\hat{\mathbf{y}} = w_1 \mathbf{x}_{:,1} + \cdots + w_D \mathbf{x}_{:,D} = \mathbf{X}\mathbf{w} \tag{11.13}$$

To minimize the norm of the residual, $\mathbf{y} - \hat{\mathbf{y}}$, we want the residual vector to be orthogonal to every column of $\mathbf{X}$. Hence

$$\mathbf{x}_{:,d}^\top (\mathbf{y} - \hat{\mathbf{y}}) = 0 \Rightarrow \mathbf{X}^\top (\mathbf{y} - \mathbf{X}\mathbf{w}) = \mathbf{0} \Rightarrow \mathbf{w} = (\mathbf{X}^\top \mathbf{X})^{-1} \mathbf{X}^\top \mathbf{y} \tag{11.14}$$

Figure 11.2: (a) Contours of the RSS error surface for the example in Figure 11.1a. The blue cross represents the MLE. (b) Corresponding surface plot. Generated by linreg_contours_sse_plot.ipynb.

Figure 11.3: Graphical interpretation of least squares for $m = 3$ equations and $n = 2$ unknowns when solving the system $\mathbf{A}\mathbf{w} = \mathbf{b}$. $\mathbf{a}_1$ and $\mathbf{a}_2$ are the columns of $\mathbf{A}$, which define a 2d linear subspace embedded in $\mathbb{R}^3$. The target vector $\mathbf{b}$ is a vector in $\mathbb{R}^3$; its orthogonal projection onto the linear subspace is denoted $\hat{\mathbf{b}}$. The line from $\mathbf{b}$ to $\hat{\mathbf{b}}$ is the vector of residual errors, whose norm we want to minimize.
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Hence our projected value of y is given by 
y = Xw = X(X'X) 'XTly (11.15) 


This corresponds to an orthogonal projection of y onto the column space of X. For example, 
consider the case where we have N = 3 training examples, each of dimensionality D = 2. The 
training data defines a 2d linear subspace, defined by the 2 columns of X, each of which is a point in 
3d. We project y, which is also a point in 3d, onto this 2d subspace, as shown in Figure 11.3. 

The projection matrix 


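$$\mathrm{Proj}(\mathbf{X}) = \mathbf{X}(\mathbf{X}^\top\mathbf{X})^{-1}\mathbf{X}^\top \tag{11.16}$$

is sometimes called the hat matrix, since $\hat{\mathbf{y}} = \mathrm{Proj}(\mathbf{X})\mathbf{y}$. In the special case that X = x is a column 
vector, the orthogonal projection of y onto the line x becomes 

$$\mathrm{Proj}(\mathbf{x})\mathbf{y} = \mathbf{x}\,\frac{\mathbf{x}^\top\mathbf{y}}{\mathbf{x}^\top\mathbf{x}} \tag{11.17}$$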
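
The projection view can be checked numerically. The following is a minimal NumPy sketch, assuming randomly generated data with N = 3 and D = 2 (the variable names and the random seed are illustrative choices, not taken from the book's notebooks). It builds the hat matrix of Equation (11.16) and the single-column projection of Equation (11.17).

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 3, 2                       # N = 3 examples, D = 2 features, as in Figure 11.3
X = rng.normal(size=(N, D))
y = rng.normal(size=N)

# Hat matrix Proj(X) = X (X^T X)^{-1} X^T; solve() avoids forming the inverse explicitly.
P = X @ np.linalg.solve(X.T @ X, X.T)
y_hat = P @ y                     # orthogonal projection of y onto the column space of X
residual = y - y_hat

# Special case of a single column x: Proj(x) y = x (x^T y) / (x^T x).
x = X[:, 0]
y_hat_line = x * (x @ y) / (x @ x)
```

The residual should be orthogonal to every column of X, and P should be symmetric and idempotent, which is what characterizes an orthogonal projection.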


11.2.2.3 Algorithmic issues 
Recall that the OLS solution is 


$$\hat{\mathbf{w}} = \mathbf{X}^{\dagger}\mathbf{y} = (\mathbf{X}^\top\mathbf{X})^{-1}\mathbf{X}^\top\mathbf{y} \tag{11.18}$$

However, even if it is theoretically possible to compute the pseudo-inverse by inverting $\mathbf{X}^\top\mathbf{X}$, we 
should not do so for numerical reasons, since $\mathbf{X}^\top\mathbf{X}$ may be ill conditioned or singular. 

A better (and more general) approach is to compute the pseudo-inverse using the SVD. Indeed, if 
you look at the source code for the fit method of sklearn.linear_model.LinearRegression, you will see that it uses the 
scipy.linalg.lstsq function, which in turn calls DGELSD, which is an SVD-based solver implemented 
by the LAPACK library, written in Fortran.¹ 

However, if X is tall and skinny (i.e., $N \gg D$), it can be quicker to use QR decomposition 
(Section 7.6.2). To do this, let X = QR, where $\mathbf{Q}^\top\mathbf{Q} = \mathbf{I}$. In Section 7.7, we show that OLS is 
equivalent to solving the system of linear equations Xw = y in a way that minimizes $\|\mathbf{X}\mathbf{w} - \mathbf{y}\|_2^2$. 
(If N = D and X is full rank, the equations have a unique solution, and the error will be 0.) Using 
QR decomposition, we can rewrite this system of equations as follows: 

\begin{align}
(\mathbf{Q}\mathbf{R})\mathbf{w} &= \mathbf{y} \tag{11.19} \\
\mathbf{Q}^\top\mathbf{Q}\mathbf{R}\mathbf{w} &= \mathbf{Q}^\top\mathbf{y} \tag{11.20} \\
\mathbf{w} &= \mathbf{R}^{-1}(\mathbf{Q}^\top\mathbf{y}) \tag{11.21}
\end{align}

Since R is upper triangular, we can solve this last set of equations using backsubstitution, thus 
avoiding matrix inversion. See linsys_solve_demo.ipynb for a demo. 
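
As a rough illustration of Equations (11.19) to (11.21), the sketch below solves a least squares problem with a reduced QR decomposition followed by backsubstitution. It is not the code from linsys_solve_demo.ipynb; the helper `back_substitution` and the synthetic data are assumptions made for this example, and `np.linalg.lstsq` is used only as an SVD-based reference.

```python
import numpy as np

def back_substitution(R, b):
    """Solve R w = b for upper-triangular R without forming R^{-1}."""
    D = R.shape[0]
    w = np.zeros(D)
    for i in range(D - 1, -1, -1):
        # Subtract the contribution of the already-solved components, then divide.
        w[i] = (b[i] - R[i, i + 1:] @ w[i + 1:]) / R[i, i]
    return w

rng = np.random.default_rng(0)
N, D = 200, 5                                   # tall and skinny: N >> D
X = rng.normal(size=(N, D))
w_true = rng.normal(size=D)
y = X @ w_true + 0.1 * rng.normal(size=N)

Q, R = np.linalg.qr(X)                          # reduced QR: Q is N x D, R is D x D
w_qr = back_substitution(R, Q.T @ y)            # Equation (11.21)

w_svd, *_ = np.linalg.lstsq(X, y, rcond=None)   # SVD-based reference solution
```

In practice a library triangular solver would replace the explicit loop; the loop is written out here only to make the backsubstitution step visible.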

An alternative to the use of direct methods based on matrix decomposition (such as SVD and QR) 
is to use iterative solvers, such as the conjugate gradient method (which assumes X is symmetric 
positive definite), and GMRES (the generalized minimal residual method), which works for general 
X. (In SciPy, this is implemented by sparse.linalg.gmres.) These methods just require the ability 
to perform matrix-vector multiplications (i.e., an implementation of a linear operator), and thus 
are well-suited to problems where X is sparse or structured. For details, see e.g., [TB97]. 


1. Note that a lot of the “Python” scientific computing stack sits on top of source code that is written in Fortran 
or C++, for reasons of speed. This makes it hard to change the underlying algorithms. By contrast, the scientific 
computing libraries in the Julia language are written in Julia itself, aiding clarity without sacrificing speed. 

A final important issue is that it is usually essential to standardize the input features before 
fitting the model, to ensure that they are zero mean and unit variance. We can do this using 
Equation (10.51). 


11.2.2.4 Weighted least squares 


In some cases, we want to associate a weight with each example. For example, in heteroskedastic 
regression, the variance depends on the input, so the model has the form 


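$$p(y \mid \mathbf{x}; \boldsymbol{\theta}) = \mathcal{N}(y \mid \mathbf{w}^\top\mathbf{x}, \sigma^2(\mathbf{x})) = \frac{1}{\sqrt{2\pi\sigma^2(\mathbf{x})}} \exp\left(-\frac{1}{2\sigma^2(\mathbf{x})}(y - \mathbf{w}^\top\mathbf{x})^2\right) \tag{11.22}$$

Thus 

$$p(\mathbf{y} \mid \mathbf{x}; \boldsymbol{\theta}) = \mathcal{N}(\mathbf{y} \mid \mathbf{X}\mathbf{w}, \boldsymbol{\Lambda}^{-1}) \tag{11.23}$$

where $\boldsymbol{\Lambda} = \mathrm{diag}(1/\sigma^2(\mathbf{x}_n))$. This is known as weighted linear regression. One can show that the 
MLE is given by 

$$\hat{\mathbf{w}} = (\mathbf{X}^\top\boldsymbol{\Lambda}\mathbf{X})^{-1}\mathbf{X}^\top\boldsymbol{\Lambda}\mathbf{y} \tag{11.24}$$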


This is known as the weighted least squares estimate. 
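
The weighted estimate of Equation (11.24) can be sketched in a few lines of NumPy. The noise model below (a standard deviation that grows with the first feature) is a hypothetical choice made only to produce heteroskedastic data; the variable names are likewise illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 500, 3
X = rng.normal(size=(N, D))
w_true = np.array([1.0, -2.0, 0.5])

# Hypothetical noise model: the standard deviation grows with |x_1|.
sigma = 0.1 + np.abs(X[:, 0])
y = X @ w_true + sigma * rng.normal(size=N)

lam = 1.0 / sigma**2                      # diagonal of Lambda = diag(1 / sigma^2(x_n))
XtL = X.T * lam                           # X^T Lambda, without building the N x N matrix
w_wls = np.linalg.solve(XtL @ X, XtL @ y)

# Equivalent view: rescale each row by 1 / sigma_n and run ordinary least squares.
w_rescaled, *_ = np.linalg.lstsq(X / sigma[:, None], y / sigma, rcond=None)
```

The second computation shows one way to think about the weights: dividing each example by its noise standard deviation turns the problem back into ordinary least squares with unit noise variance.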


11.2.3 Other approaches to computing the MLE 


In this section, we discuss other approaches for computing the MLE. 


11.2.3.1 Solving for offset and slope separately 


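Typically we use a model of the form $p(y \mid \mathbf{x}, \boldsymbol{\theta}) = \mathcal{N}(y \mid w_0 + \mathbf{w}^\top\mathbf{x}, \sigma^2)$, where $w_0$ is an offset or “bias” 
term. We can compute $(w_0, \mathbf{w})$ at the same time by adding a column of 1s to X, and then computing 
the MLE as above. Alternatively, we can solve for w and $w_0$ separately. (This will be useful later.) 
In particular, one can show that 

\begin{align}
\hat{\mathbf{w}} &= (\mathbf{X}_c^\top\mathbf{X}_c)^{-1}\mathbf{X}_c^\top\mathbf{y}_c = \left[\sum_{n=1}^N (\mathbf{x}_n - \bar{\mathbf{x}})(\mathbf{x}_n - \bar{\mathbf{x}})^\top\right]^{-1} \left[\sum_{n=1}^N (y_n - \bar{y})(\mathbf{x}_n - \bar{\mathbf{x}})\right] \tag{11.25} \\
\hat{w}_0 &= \frac{1}{N}\sum_n y_n - \frac{1}{N}\sum_n \mathbf{x}_n^\top\hat{\mathbf{w}} = \bar{y} - \bar{\mathbf{x}}^\top\hat{\mathbf{w}} \tag{11.26}
\end{align}

where $\mathbf{X}_c$ is the centered input matrix containing $\mathbf{x}_n^c = \mathbf{x}_n - \bar{\mathbf{x}}$ along its rows, and $\mathbf{y}_c = \mathbf{y} - \bar{y}$ is the 
centered output vector. Thus we can first compute $\hat{\mathbf{w}}$ on centered data, and then estimate $w_0$ using 
$\bar{y} - \bar{\mathbf{x}}^\top\hat{\mathbf{w}}$. 

Note that, if we write the model in the form $\hat{y} = \beta_0 + \boldsymbol{\beta}^\top(\mathbf{x} - \bar{\mathbf{x}})$, then we have $\hat{\boldsymbol{\beta}} = \hat{\mathbf{w}}$ and 
$\hat{w}_0 = \hat{\beta}_0 - \hat{\boldsymbol{\beta}}^\top\bar{\mathbf{x}}$, so $\hat{\beta}_0 = \bar{y}$. 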
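
The equivalence between the two routes (a joint solve with a column of 1s, versus centering first) can be checked with a short NumPy sketch. The synthetic data and variable names below are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 100, 3
X = rng.normal(loc=2.0, size=(N, D))              # inputs with a non-zero mean
y = 4.0 + X @ np.array([1.0, -1.0, 0.5]) + 0.1 * rng.normal(size=N)

# Route 1: append a column of 1s and solve for (w0, w) jointly.
X1 = np.column_stack([np.ones(N), X])
theta, *_ = np.linalg.lstsq(X1, y, rcond=None)
w0_joint, w_joint = theta[0], theta[1:]

# Route 2: fit w on centered data (Equation 11.25), then recover the offset (Equation 11.26).
x_bar, y_bar = X.mean(axis=0), y.mean()
Xc, yc = X - x_bar, y - y_bar
w_sep, *_ = np.linalg.lstsq(Xc, yc, rcond=None)
w0_sep = y_bar - x_bar @ w_sep
```

Both routes should agree up to floating point error.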
11.2.3.2 Simple linear regression (1d inputs) 


In the case of 1d (scalar) inputs, the results from Section 11.2.3.1 reduce to the following simple 
form, which may be familiar from basic statistics classes: 

\begin{align}
\hat{w} &= \frac{\sum_n (x_n - \bar{x})(y_n - \bar{y})}{\sum_n (x_n - \bar{x})^2} = \frac{C_{xy}}{C_{xx}} = \rho_{xy}\frac{\sigma_y}{\sigma_x} \tag{11.27} \\
\hat{w}_0 &= \mathbb{E}[y] - \hat{w}\,\mathbb{E}[x] \approx \bar{y} - \hat{w}\bar{x} \tag{11.28}
\end{align}

where $C_{xx} = \mathrm{Cov}[X, X] = \mathbb{V}[X] = \sigma_x^2$, $C_{yy} = \mathrm{Cov}[Y, Y] = \mathbb{V}[Y] = \sigma_y^2$, $C_{xy} = \mathrm{Cov}[X, Y]$, and 
$\rho_{xy} = \frac{C_{xy}}{\sigma_x\sigma_y}$. We will use this result below. 


11.2.3.3 Partial regression 


From Equation (11.27), we can compute the regression coefficient of Y on X as follows: 


$$R_{YX} \triangleq \frac{\partial}{\partial x}\mathbb{E}[Y \mid X = x] = w_1 = \frac{C_{xy}}{C_{xx}} \tag{11.29}$$

This is the slope of the linear prediction for Y given X. 

Now consider the case where we have 2 inputs, so $Y = w_0 + w_1 X_1 + w_2 X_2 + \epsilon$, where $\mathbb{E}[\epsilon] = 0$. 
One can show that the optimal regression coefficient for $w_1$ is given by $R_{YX_1 \cdot X_2}$, which is the partial 
regression coefficient of Y on $X_1$, keeping $X_2$ constant: 

$$w_1 = R_{YX_1 \cdot X_2} = \frac{\partial}{\partial x}\mathbb{E}[Y \mid X_1 = x, X_2] \tag{11.30}$$

Note that this quantity is invariant to the specific value of $X_2$ we condition on. 

We can derive $w_2$ in a similar manner. Indeed, we can extend this to multiple input variables. In 
each case, we find the optimal coefficients are equal to the partial regression coefficients. This means 
that we can interpret the j'th coefficient $\hat{w}_j$ as the change in output y we expect per unit change in 
input $x_j$, keeping all the other inputs constant. 


11.2.3.4 Recursively computing the MLE 


OLS is a batch method for computing the MLE. In some applications, the data arrives in a continual 
stream, so we want to compute the estimate online, or recursively, as we discussed in Section 4.4.2. 
In this section, we show how to do this for the case of simple (1d) linear regression. 

Recall from Section 11.2.3.2 that the batch MLE for simple linear regression is given by 


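\begin{align}
\hat{w}_1 &= \frac{\sum_n (x_n - \bar{x})(y_n - \bar{y})}{\sum_n (x_n - \bar{x})^2} = \frac{C_{xy}}{C_{xx}} \tag{11.31} \\
\hat{w}_0 &= \bar{y} - \hat{w}_1\bar{x} \tag{11.32}
\end{align}

where $C_{xy} = \mathrm{Cov}[X, Y]$ and $C_{xx} = \mathrm{Cov}[X, X] = \mathbb{V}[X]$. 

We now discuss how to compute these results in a recursive fashion. To do this, let us define the 
following sufficient statistics: 

\begin{align}
\bar{x}^{(n)} &= \frac{1}{n}\sum_{i=1}^n x_i, \quad \bar{y}^{(n)} = \frac{1}{n}\sum_{i=1}^n y_i \tag{11.33} \\
C_{xx}^{(n)} &= \frac{1}{n}\sum_{i=1}^n (x_i - \bar{x})^2, \quad C_{xy}^{(n)} = \frac{1}{n}\sum_{i=1}^n (x_i - \bar{x})(y_i - \bar{y}), \quad C_{yy}^{(n)} = \frac{1}{n}\sum_{i=1}^n (y_i - \bar{y})^2 \tag{11.34}
\end{align}

We can update the means online using 

$$\bar{x}^{(n+1)} = \bar{x}^{(n)} + \frac{1}{n+1}(x_{n+1} - \bar{x}^{(n)}), \quad \bar{y}^{(n+1)} = \bar{y}^{(n)} + \frac{1}{n+1}(y_{n+1} - \bar{y}^{(n)}) \tag{11.35}$$

To update the covariance terms, let us first rewrite $C_{xy}^{(n)}$ as follows: 

\begin{align}
C_{xy}^{(n)} &= \frac{1}{n}\left[\left(\sum_{i=1}^n x_i y_i\right) + \left(\sum_{i=1}^n \bar{x}^{(n)}\bar{y}^{(n)}\right) - \bar{x}^{(n)}\left(\sum_{i=1}^n y_i\right) - \bar{y}^{(n)}\left(\sum_{i=1}^n x_i\right)\right] \tag{11.36} \\
&= \frac{1}{n}\left[\left(\sum_{i=1}^n x_i y_i\right) + n\bar{x}^{(n)}\bar{y}^{(n)} - \bar{x}^{(n)} n\bar{y}^{(n)} - \bar{y}^{(n)} n\bar{x}^{(n)}\right] \tag{11.37} \\
&= \frac{1}{n}\left[\left(\sum_{i=1}^n x_i y_i\right) - n\bar{x}^{(n)}\bar{y}^{(n)}\right] \tag{11.38}
\end{align}

Hence 

$$\sum_{i=1}^n x_i y_i = nC_{xy}^{(n)} + n\bar{x}^{(n)}\bar{y}^{(n)} \tag{11.39}$$

and so 

$$C_{xy}^{(n+1)} = \frac{1}{n+1}\left[x_{n+1}y_{n+1} + nC_{xy}^{(n)} + n\bar{x}^{(n)}\bar{y}^{(n)} - (n+1)\bar{x}^{(n+1)}\bar{y}^{(n+1)}\right] \tag{11.40}$$

We can derive the update for $C_{xx}$ in a similar manner. 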


See Figure 11.4 for a simple illustration of these equations in action for a 1d regression model. 
To extend the above analysis to D-dimensional inputs, the easiest approach is to use SGD. The 
resulting algorithm is called the least mean squares algorithm; see Section 8.4.2 for details. 
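
The recursive updates for the 1d case can be sketched as follows. This is not the code from linregOnlineDemo.ipynb; the function name `recursive_simple_regression`, the synthetic data, and the choice to skip the estimate until the running variance is positive are all assumptions made for this illustration.

```python
import numpy as np

def recursive_simple_regression(xs, ys):
    """Stream (x, y) pairs and return the running estimates (w0, w1) after each step."""
    x_bar = y_bar = c_xx = c_xy = 0.0
    history = []
    for n, (x, y) in enumerate(zip(xs, ys)):       # n = number of points seen so far
        x_new = x_bar + (x - x_bar) / (n + 1)       # Equation (11.35)
        y_new = y_bar + (y - y_bar) / (n + 1)
        # Equation (11.40), and the analogous update for C_xx.
        c_xy = (x * y + n * c_xy + n * x_bar * y_bar - (n + 1) * x_new * y_new) / (n + 1)
        c_xx = (x * x + n * c_xx + n * x_bar * x_bar - (n + 1) * x_new * x_new) / (n + 1)
        x_bar, y_bar = x_new, y_new
        if c_xx > 0:                                # the slope is undefined after one point
            w1 = c_xy / c_xx
            history.append((y_bar - w1 * x_bar, w1))
    return history

rng = np.random.default_rng(0)
xs = rng.uniform(0, 10, size=50)
ys = 1.5 - 0.7 * xs + rng.normal(scale=0.5, size=50)
history = recursive_simple_regression(xs, ys)
w0_final, w1_final = history[-1]
```

After the last point has been processed, the recursive estimate should match the batch least squares fit on the same data.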


11.2.3.5 Deriving the MLE from a generative perspective 


Linear regression is a discriminative model of the form p(y|x). However, we can also use generative 
models for regression, by analogy to how we use generative models for classification in Chapter 9. 
The goal is to compute the conditional expectation 

$$f(\mathbf{x}) = \mathbb{E}[y \mid \mathbf{x}] = \int y\, p(y \mid \mathbf{x})\, dy = \frac{\int y\, p(\mathbf{x}, y)\, dy}{\int p(\mathbf{x}, y)\, dy} \tag{11.41}$$
Figure 11.4: Regression coefficients over time for the 1d model in Figure 1.7a(a). Generated by 
linregOnlineDemo.ipynb. 


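Suppose we fit p(x, y) using an MVN. The MLEs for the parameters of the joint distribution are the 
empirical means and covariances (see Section 4.2.6 for a proof of this result): 

\begin{align}
\boldsymbol{\mu}_x &= \frac{1}{N}\sum_n \mathbf{x}_n \tag{11.42} \\
\mu_y &= \frac{1}{N}\sum_n y_n \tag{11.43} \\
\boldsymbol{\Sigma}_{xx} &= \frac{1}{N}\sum_n (\mathbf{x}_n - \bar{\mathbf{x}})(\mathbf{x}_n - \bar{\mathbf{x}})^\top = \frac{1}{N}\mathbf{X}_c^\top\mathbf{X}_c \tag{11.44} \\
\boldsymbol{\Sigma}_{xy} &= \frac{1}{N}\sum_n (\mathbf{x}_n - \bar{\mathbf{x}})(y_n - \bar{y}) = \frac{1}{N}\mathbf{X}_c^\top\mathbf{y}_c \tag{11.45}
\end{align}

Hence from Equation (3.28), we have 

$$\mathbb{E}[y \mid \mathbf{x}] = \mu_y + \boldsymbol{\Sigma}_{xy}^\top\boldsymbol{\Sigma}_{xx}^{-1}(\mathbf{x} - \boldsymbol{\mu}_x) \tag{11.46}$$

We can rewrite this as $\mathbb{E}[y \mid \mathbf{x}] = w_0 + \mathbf{w}^\top\mathbf{x}$ by defining 

\begin{align}
w_0 &= \mu_y - \mathbf{w}^\top\boldsymbol{\mu}_x = \bar{y} - \mathbf{w}^\top\bar{\mathbf{x}} \tag{11.47} \\
\mathbf{w} &= \boldsymbol{\Sigma}_{xx}^{-1}\boldsymbol{\Sigma}_{xy} = (\mathbf{X}_c^\top\mathbf{X}_c)^{-1}\mathbf{X}_c^\top\mathbf{y}_c \tag{11.48}
\end{align}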


This matches the MLEs for the discriminative model as we showed in Section 11.2.3.1. Thus we see 
that fitting the joint model, and then conditioning it, yields the same result as fitting the conditional 
model. However, this is only true for Gaussian models (see Section 9.4 for further discussion of this 
point). 
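
The agreement between the generative and discriminative routes can be verified numerically. The sketch below, on synthetic data chosen for illustration, fits the joint Gaussian by its empirical moments, conditions it on x as in Equations (11.46) to (11.48), and compares the result with a direct least squares fit.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 200, 3
X = rng.normal(size=(N, D)) + 1.0
y = 2.0 + X @ np.array([0.5, -1.0, 2.0]) + 0.3 * rng.normal(size=N)

# MLE of the joint Gaussian: empirical means and (1/N-normalized) covariances.
mu_x, mu_y = X.mean(axis=0), y.mean()
Xc, yc = X - mu_x, y - mu_y
Sigma_xx = Xc.T @ Xc / N                    # Equation (11.44)
Sigma_xy = Xc.T @ yc / N                    # Equation (11.45)

# Condition the joint on x: E[y|x] = w0 + w^T x.
w_gen = np.linalg.solve(Sigma_xx, Sigma_xy) # Equation (11.48)
w0_gen = mu_y - w_gen @ mu_x                # Equation (11.47)

# Discriminative fit with an explicit column of 1s, for comparison.
theta, *_ = np.linalg.lstsq(np.column_stack([np.ones(N), X]), y, rcond=None)
```

Note that the 1/N factors in the two covariance estimates cancel, which is why the conditioned joint model reproduces the least squares solution exactly.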


11.2.3.6 Deriving the MLE for $\sigma^2$ 

After estimating $\hat{\mathbf{w}}_{\text{mle}}$ using one of the above methods, we can estimate the noise variance. It is easy 
to show that the MLE is given by 

$$\hat{\sigma}^2_{\text{mle}} = \operatorname*{argmin}_{\sigma^2} \mathrm{NLL}(\hat{\mathbf{w}}, \sigma^2) = \frac{1}{N}\sum_{n=1}^N (y_n - \mathbf{x}_n^\top\hat{\mathbf{w}})^2 \tag{11.49}$$

This is just the MSE of the residuals, which is an intuitive result. 


Figure 11.5: Residual plot for polynomial regression of degree 1 and 2 for the functions in Figure 1.7a(a-b). 
Generated by linreg_poly_vs_degree.ipynb. 

Figure 11.6: Fit vs actual plots for polynomial regression of degree 1 and 2 for the functions in Figure 1.7a(a-b). 
Generated by linreg_poly_vs_degree.ipynb. 


11.2.4 Measuring goodness of fit 


In this section, we discuss some simple ways to assess how well a regression model fits the data 
(which is known as goodness of fit). 


11.2.4.1 Residual plots 


For 1d inputs, we can check the reasonableness of the model by plotting the residuals, $r_n = y_n - \hat{y}_n$, 
vs the input $x_n$. This is called a residual plot. The model assumes that the residuals have a 
$\mathcal{N}(0, \sigma^2)$ distribution, so the residual plot should be a cloud of points more or less equally above and 
below the horizontal line at 0, without any obvious trends. 

As an example, in Figure 11.5(a), we plot the residuals for the linear model in Figure 1.7a(a). We 
see that there is some curved structure to the residuals, indicating a lack of fit. In Figure 11.5(b), we 
plot the residuals for the quadratic model in Figure 1.7a(b). We see a much better fit. 
To extend this approach to multi-dimensional inputs, we can plot predictions $\hat{y}_n$ vs the true output 
$y_n$, rather than plotting vs $x_n$. A good model will have points that lie on a diagonal line. See 
Figure 11.6 for some examples. 


11.2.4.2 Prediction accuracy and $R^2$ 


We can assess the fit quantitatively by computing the RSS (residual sum of squares) on the dataset: 
$\mathrm{RSS}(\mathbf{w}) = \sum_{n=1}^N (y_n - \mathbf{w}^\top\mathbf{x}_n)^2$. A model with lower RSS fits the data better. Another measure that 
is used is root mean squared error or RMSE: 

$$\mathrm{RMSE}(\mathbf{w}) \triangleq \sqrt{\frac{1}{N}\mathrm{RSS}(\mathbf{w})} \tag{11.50}$$

A more interpretable measure can be computed using the coefficient of determination, denoted 
by $R^2$: 

$$R^2 \triangleq 1 - \frac{\sum_{n=1}^N (\hat{y}_n - y_n)^2}{\sum_{n=1}^N (\bar{y} - y_n)^2} = 1 - \frac{\mathrm{RSS}}{\mathrm{TSS}} \tag{11.51}$$

where $\bar{y} = \frac{1}{N}\sum_{n=1}^N y_n$ is the empirical mean of the response, $\mathrm{RSS} = \sum_{n=1}^N (y_n - \hat{y}_n)^2$ is the residual 
sum of squares, and $\mathrm{TSS} = \sum_{n=1}^N (y_n - \bar{y})^2$ is the total sum of squares. Thus we see that $R^2$ measures 
the variance in the predictions relative to a simple constant prediction of $\hat{y}_n = \bar{y}$. If a model does no 
better at predicting than using the mean of the output, then we have $R^2 = 0$. If the model perfectly fits 
the data, then the RSS will be 0, so $R^2 = 1$. In general, larger values imply greater reduction in 
variance (better fit). This is illustrated in Figure 11.6. 
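
These quantities are straightforward to compute. The helper below, named `goodness_of_fit` purely for this example, returns RSS, RMSE and $R^2$ for a vector of targets and predictions, and is applied to the two limiting cases mentioned above (predicting the mean, and a perfect fit) on a small made-up target vector.

```python
import numpy as np

def goodness_of_fit(y, y_hat):
    """Return (RSS, RMSE, R^2) for targets y and predictions y_hat."""
    rss = np.sum((y - y_hat) ** 2)                # residual sum of squares
    tss = np.sum((y - np.mean(y)) ** 2)           # total sum of squares
    return rss, np.sqrt(rss / len(y)), 1.0 - rss / tss

y = np.array([1.0, 2.0, 3.0, 4.0])

# Constant prediction at the mean of y: RSS equals TSS.
rss_mean, rmse_mean, r2_mean = goodness_of_fit(y, np.full_like(y, y.mean()))

# Perfect prediction: RSS is zero.
rss_perfect, rmse_perfect, r2_perfect = goodness_of_fit(y, y)
```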


11.3 Ridge regression 


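Maximum likelihood estimation can result in overfitting, as we discussed in Section 1.2.2.2. A 
simple solution to this is to use MAP estimation with a zero-mean Gaussian prior on the weights, 
$p(\mathbf{w}) = \mathcal{N}(\mathbf{w} \mid \mathbf{0}, \lambda^{-1}\mathbf{I})$, as we discussed in Section 4.5.3. This is called ridge regression. 

In more detail, we compute the MAP estimate as follows: 

\begin{align}
\hat{\mathbf{w}}_{\text{map}} &= \operatorname*{argmin}_{\mathbf{w}} \frac{1}{2\sigma^2}(\mathbf{y} - \mathbf{X}\mathbf{w})^\top(\mathbf{y} - \mathbf{X}\mathbf{w}) + \frac{1}{2\tau^2}\mathbf{w}^\top\mathbf{w} \tag{11.52} \\
&= \operatorname*{argmin}_{\mathbf{w}} \mathrm{RSS}(\mathbf{w}) + \lambda\|\mathbf{w}\|_2^2 \tag{11.53}
\end{align}

where $\lambda \triangleq \sigma^2/\tau^2$ is proportional to the strength of the prior, and 

$$\|\mathbf{w}\|_2 \triangleq \sqrt{\sum_{d=1}^D |w_d|^2} = \sqrt{\mathbf{w}^\top\mathbf{w}} \tag{11.54}$$

is the $\ell_2$ norm of the vector w. Thus we are penalizing weights that become too large in magnitude. 
In general, this technique is called $\ell_2$ regularization or weight decay, and is very widely used. 
See Figure 4.5 for an illustration. 

Note that we do not penalize the offset term $w_0$, since that only affects the global mean of the 
output, and does not contribute to overfitting. See Exercise 11.2. 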
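11.3.1 Computing the MAP estimate 


In this section, we discuss algorithms for computing the MAP estimate. 
The MAP estimate corresponds to minimizing the following penalized objective: 

$$J(\mathbf{w}) = (\mathbf{y} - \mathbf{X}\mathbf{w})^\top(\mathbf{y} - \mathbf{X}\mathbf{w}) + \lambda\|\mathbf{w}\|_2^2 \tag{11.55}$$

where $\lambda = \sigma^2/\tau^2$ is the strength of the regularizer. The derivative is given by 

$$\nabla_{\mathbf{w}} J(\mathbf{w}) = 2(\mathbf{X}^\top\mathbf{X}\mathbf{w} - \mathbf{X}^\top\mathbf{y} + \lambda\mathbf{w}) \tag{11.56}$$

and hence 

$$\hat{\mathbf{w}}_{\text{map}} = (\mathbf{X}^\top\mathbf{X} + \lambda\mathbf{I}_D)^{-1}\mathbf{X}^\top\mathbf{y} = \left(\sum_n \mathbf{x}_n\mathbf{x}_n^\top + \lambda\mathbf{I}_D\right)^{-1}\left(\sum_n y_n\mathbf{x}_n\right) \tag{11.57}$$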
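
A minimal sketch of Equation (11.57) in NumPy follows. The function name `ridge_fit` and the synthetic data are assumptions for illustration; a linear solve is used in place of an explicit inverse. The gradient of Equation (11.56) is evaluated at the solution as a sanity check.

```python
import numpy as np

def ridge_fit(X, y, lam):
    """MAP estimate of Equation (11.57), using a linear solve rather than an inverse."""
    D = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(D), X.T @ y)

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 4))
y = X @ np.array([1.0, 0.0, -2.0, 0.5]) + 0.1 * rng.normal(size=50)

lam = 3.0
w_map = ridge_fit(X, y, lam)
grad = 2 * (X.T @ X @ w_map - X.T @ y + lam * w_map)   # Equation (11.56) at the solution
w_mle = ridge_fit(X, y, 0.0)                            # lam = 0 recovers least squares
```

The gradient should vanish at `w_map`, and the ridge weights should have a smaller norm than the unregularized ones.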


11.3.1.1 Solving using QR 


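Naively computing the primal estimate $\hat{\mathbf{w}} = (\mathbf{X}^\top\mathbf{X} + \lambda\mathbf{I})^{-1}\mathbf{X}^\top\mathbf{y}$ using matrix inversion is a bad idea, 
since it can be slow and numerically unstable. In this section, we describe a way to convert the 
problem to a standard least squares problem, to which we can apply QR decomposition, as discussed 
in Section 11.2.2.3. 

We assume the prior has the form $p(\mathbf{w}) = \mathcal{N}(\mathbf{0}, \boldsymbol{\Lambda}^{-1})$, where $\boldsymbol{\Lambda}$ is the precision matrix. In the case 
of ridge regression, $\boldsymbol{\Lambda} = (1/\tau^2)\mathbf{I}$. We can emulate this prior by adding “virtual data” to the training 
set to get 

$$\tilde{\mathbf{X}} = \begin{pmatrix} \mathbf{X}/\sigma \\ \sqrt{\boldsymbol{\Lambda}} \end{pmatrix}, \quad \tilde{\mathbf{y}} = \begin{pmatrix} \mathbf{y}/\sigma \\ \mathbf{0}_{D \times 1} \end{pmatrix} \tag{11.58}$$

where $\boldsymbol{\Lambda} = \sqrt{\boldsymbol{\Lambda}}^\top\sqrt{\boldsymbol{\Lambda}}$ is a Cholesky decomposition of $\boldsymbol{\Lambda}$. We see that $\tilde{\mathbf{X}}$ is $(N + D) \times D$, where the 
extra rows represent pseudo-data from the prior. 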


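We now show that the RSS on this expanded data is equivalent to penalized RSS on the original 
data: 

\begin{align}
f(\mathbf{w}) &= (\tilde{\mathbf{y}} - \tilde{\mathbf{X}}\mathbf{w})^\top(\tilde{\mathbf{y}} - \tilde{\mathbf{X}}\mathbf{w}) \tag{11.59} \\
&= \left(\begin{pmatrix}\mathbf{y}/\sigma \\ \mathbf{0}\end{pmatrix} - \begin{pmatrix}\mathbf{X}/\sigma \\ \sqrt{\boldsymbol{\Lambda}}\end{pmatrix}\mathbf{w}\right)^\top \left(\begin{pmatrix}\mathbf{y}/\sigma \\ \mathbf{0}\end{pmatrix} - \begin{pmatrix}\mathbf{X}/\sigma \\ \sqrt{\boldsymbol{\Lambda}}\end{pmatrix}\mathbf{w}\right) \tag{11.60} \\
&= \begin{pmatrix}\frac{1}{\sigma}(\mathbf{y} - \mathbf{X}\mathbf{w}) \\ -\sqrt{\boldsymbol{\Lambda}}\mathbf{w}\end{pmatrix}^\top \begin{pmatrix}\frac{1}{\sigma}(\mathbf{y} - \mathbf{X}\mathbf{w}) \\ -\sqrt{\boldsymbol{\Lambda}}\mathbf{w}\end{pmatrix} \tag{11.61} \\
&= \frac{1}{\sigma^2}(\mathbf{y} - \mathbf{X}\mathbf{w})^\top(\mathbf{y} - \mathbf{X}\mathbf{w}) + (\sqrt{\boldsymbol{\Lambda}}\mathbf{w})^\top(\sqrt{\boldsymbol{\Lambda}}\mathbf{w}) \tag{11.62} \\
&= \frac{1}{\sigma^2}(\mathbf{y} - \mathbf{X}\mathbf{w})^\top(\mathbf{y} - \mathbf{X}\mathbf{w}) + \mathbf{w}^\top\boldsymbol{\Lambda}\mathbf{w} \tag{11.63}
\end{align}

Hence the MAP estimate is given by 

$$\hat{\mathbf{w}}_{\text{map}} = (\tilde{\mathbf{X}}^\top\tilde{\mathbf{X}})^{-1}\tilde{\mathbf{X}}^\top\tilde{\mathbf{y}} \tag{11.64}$$

which can be solved using standard OLS methods. In particular, we can compute the QR decomposition 
of $\tilde{\mathbf{X}}$, and then proceed as in Section 11.2.2.3. This takes $O((N + D)D^2)$ time. 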
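
The virtual-data construction can be sketched as follows for the ridge case, where $\boldsymbol{\Lambda} = (1/\tau^2)\mathbf{I}$ and so $\sqrt{\boldsymbol{\Lambda}} = (1/\tau)\mathbf{I}$. The values of `sigma` and `tau` and the synthetic data are illustrative assumptions; the result is compared with the direct formula using $\lambda = \sigma^2/\tau^2$.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 60, 4
X = rng.normal(size=(N, D))
y = X @ rng.normal(size=D) + 0.5 * rng.normal(size=N)

sigma, tau = 0.5, 2.0                       # noise and prior standard deviations (illustrative)
lam = sigma**2 / tau**2
sqrt_Lambda = np.eye(D) / tau               # square root of the precision Lambda = I / tau^2

# Stack the real data on top of D rows of "virtual data" from the prior (Equation 11.58).
X_tilde = np.vstack([X / sigma, sqrt_Lambda])
y_tilde = np.concatenate([y / sigma, np.zeros(D)])

# Ordinary least squares on the expanded data, via QR.
Q, R = np.linalg.qr(X_tilde)
w_aug = np.linalg.solve(R, Q.T @ y_tilde)   # a triangular solver would exploit R's structure

# Direct ridge solution for comparison.
w_direct = np.linalg.solve(X.T @ X + lam * np.eye(D), X.T @ y)
```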
11.3.1.2 Solving using SVD 


In this section, we assume D > N, which is the usual case when using ridge regression. In this case, 
it is faster to use SVD than QR. To see how this works, let $\mathbf{X} = \mathbf{U}\mathbf{S}\mathbf{V}^\top$ be the SVD of X, where 
$\mathbf{V}^\top\mathbf{V} = \mathbf{I}_N$, $\mathbf{U}\mathbf{U}^\top = \mathbf{U}^\top\mathbf{U} = \mathbf{I}_N$, and S is a diagonal N × N matrix. Now let R = US be an N × N 
matrix. One can show (see Exercise 18.4 of [HTF09]) that 

$$\hat{\mathbf{w}}_{\text{map}} = \mathbf{V}(\mathbf{R}^\top\mathbf{R} + \lambda\mathbf{I}_N)^{-1}\mathbf{R}^\top\mathbf{y} \tag{11.65}$$

In other words, we can replace the D-dimensional vectors $\mathbf{x}_i$ with the N-dimensional vectors $\mathbf{r}_i$ and 
perform our penalized fit as before. The overall time is now $O(DN^2)$ operations, which is less than 
$O(D^3)$ if D > N. 
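
As a check on Equation (11.65), the sketch below compares the SVD route with the D × D primal system on a random problem with D > N. The problem sizes and variable names are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 20, 500                                      # D > N
X = rng.normal(size=(N, D))
y = rng.normal(size=N)
lam = 1.0

U, s, Vt = np.linalg.svd(X, full_matrices=False)   # U: N x N, s: (N,), Vt: N x D
R = U * s                                           # R = U S, an N x N matrix
w_svd = Vt.T @ np.linalg.solve(R.T @ R + lam * np.eye(N), R.T @ y)

# Reference: the D x D primal system, which costs O(D^3).
w_primal = np.linalg.solve(X.T @ X + lam * np.eye(D), X.T @ y)
```

Only an N × N system is solved on the SVD route, which is where the savings come from when D is much larger than N.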


11.3.2 Connection between ridge regression and PCA 


In this section, we discuss an interesting connection between ridge regression and PCA (which we 
describe in Section 20.1), in order to gain further insight into why ridge regression works well. Our 
discussion is based on [HTF09, p66]. 

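Let $\mathbf{X} = \mathbf{U}\mathbf{S}\mathbf{V}^\top$ be the SVD of X, where $\mathbf{V}^\top\mathbf{V} = \mathbf{I}_N$, $\mathbf{U}\mathbf{U}^\top = \mathbf{U}^\top\mathbf{U} = \mathbf{I}_N$, and S is a diagonal 
N × N matrix. Using Equation (11.65) we can see that the ridge predictions on the training set are 
given by 

\begin{align}
\hat{\mathbf{y}} &= \mathbf{X}\hat{\mathbf{w}}_{\text{map}} = \mathbf{U}\mathbf{S}\mathbf{V}^\top\mathbf{V}(\mathbf{S}^2 + \lambda\mathbf{I})^{-1}\mathbf{S}\mathbf{U}^\top\mathbf{y} \tag{11.66} \\
&= \mathbf{U}\tilde{\mathbf{S}}\mathbf{U}^\top\mathbf{y} = \sum_{j=1}^D \mathbf{u}_j\tilde{S}_{jj}\mathbf{u}_j^\top\mathbf{y} \tag{11.67}
\end{align}

where 

$$\tilde{S}_{jj} \triangleq [\mathbf{S}(\mathbf{S}^2 + \lambda\mathbf{I})^{-1}\mathbf{S}]_{jj} = \frac{\sigma_j^2}{\sigma_j^2 + \lambda} \tag{11.68}$$

and $\sigma_j$ are the singular values of X. Hence 

$$\hat{\mathbf{y}} = \mathbf{X}\hat{\mathbf{w}}_{\text{map}} = \sum_{j=1}^D \mathbf{u}_j\frac{\sigma_j^2}{\sigma_j^2 + \lambda}\mathbf{u}_j^\top\mathbf{y} \tag{11.69}$$

In contrast, the least squares prediction is 

$$\hat{\mathbf{y}} = \mathbf{X}\hat{\mathbf{w}}_{\text{mle}} = (\mathbf{U}\mathbf{S}\mathbf{V}^\top)(\mathbf{V}\mathbf{S}^{-1}\mathbf{U}^\top\mathbf{y}) = \mathbf{U}\mathbf{U}^\top\mathbf{y} = \sum_{j=1}^D \mathbf{u}_j\mathbf{u}_j^\top\mathbf{y} \tag{11.70}$$

If $\sigma_j^2$ is large compared to $\lambda$, then $\sigma_j^2/(\sigma_j^2 + \lambda) \approx \sigma_j^2/\sigma_j^2 = 1$, so direction $\mathbf{u}_j$ is not affected, but if 
$\sigma_j^2$ is small compared to $\lambda$, then $\sigma_j^2/(\sigma_j^2 + \lambda) \approx \sigma_j^2/\lambda \approx 0$, so direction $\mathbf{u}_j$ will be 
downweighted. In view of this, we define the effective number of degrees of freedom of the model 
as follows: 

$$\mathrm{dof}(\lambda) = \sum_{j=1}^D \frac{\sigma_j^2}{\sigma_j^2 + \lambda} \tag{11.71}$$

When $\lambda = 0$, $\mathrm{dof}(\lambda) = D$, and as $\lambda \to \infty$, $\mathrm{dof}(\lambda) \to 0$. 


Figure 11.7: Geometry of ridge regression. The likelihood is shown as an ellipse, and the prior is shown as a 
circle centered on the origin. Adapted from Figure 3.15 of [Bis06]. Generated by geom_ridge.ipynb. 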

Let us try to understand why this behavior is desirable. In Section 11.7, we show that $\mathrm{Cov}[\mathbf{w} \mid \mathcal{D}] \propto (\mathbf{X}^\top\mathbf{X})^{-1}$, 
if we use a uniform prior for w. Thus the directions in which we are most uncertain 
about w are determined by the eigenvectors of $(\mathbf{X}^\top\mathbf{X})^{-1}$ with the largest eigenvalues, as shown in 
Figure 7.6. These directions correspond to the eigenvectors of $\mathbf{X}^\top\mathbf{X}$ with the smallest eigenvalues, 
and hence (from Section 7.5.2) the smallest singular values. So if $\sigma_j^2$ is small relative to $\lambda$, ridge 
regression will downweight direction $\mathbf{u}_j$. 
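
The per-direction shrinkage factors $\sigma_j^2/(\sigma_j^2 + \lambda)$ and the resulting effective degrees of freedom are easy to compute from the singular values. In the sketch below, the function name `ridge_shrinkage`, the column scales, and the values of $\lambda$ are illustrative assumptions.

```python
import numpy as np

def ridge_shrinkage(X, lam):
    """Per-direction factors sigma_j^2 / (sigma_j^2 + lam) and their sum, dof(lam)."""
    s = np.linalg.svd(X, compute_uv=False)        # singular values, largest first
    factors = s**2 / (s**2 + lam)
    return factors, factors.sum()

rng = np.random.default_rng(0)
N, D = 100, 5
# Columns with very different scales, so the singular values are spread out.
X = rng.normal(size=(N, D)) * np.array([10.0, 5.0, 1.0, 0.5, 0.1])

factors_0, dof_0 = ridge_shrinkage(X, 0.0)        # no regularization
factors_mid, dof_mid = ridge_shrinkage(X, 10.0)   # moderate regularization
_, dof_big = ridge_shrinkage(X, 1e12)             # very strong regularization
```

With moderate regularization, directions with large singular values keep factors close to 1 while those with small singular values are pushed towards 0, which is the "soft" weighting described in the text.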

This process is illustrated in Figure 11.7. The horizontal $w_1$ parameter is not well determined 
by the data (has high posterior variance), but the vertical $w_2$ parameter is well determined. Hence 
$\hat{w}_{\text{map}}(2)$ is close to $\hat{w}_{\text{mle}}(2)$, but $\hat{w}_{\text{map}}(1)$ is shifted strongly towards the prior mean, which is 0. In 
this way, ill-determined parameters are reduced in size towards 0. This is called shrinkage. 

There is a related, but different, technique called principal components regression, which is 
a supervised version of PCA, which we explain in Section 20.1. The idea is this: first use PCA to 
reduce the dimensionality to K dimensions, and then use these low dimensional features as input to 
regression. However, this technique does not work as well as ridge regression in terms of predictive 
accuracy [HTF01, p70]. The reason is that in PC regression, only the first K (derived) dimensions 
are retained, and the remaining D — K dimensions are entirely ignored. By contrast, ridge regression 
uses a “soft” weighting of all the dimensions. 


11.3.3 Choosing the strength of the regularizer 


To find the optimal value of $\lambda$, we can try a finite number of distinct values, and use cross validation 
to estimate their expected loss, as discussed in Section 4.5.5.2. See Figure 4.5d for an example. 

This approach can be quite expensive if we have many values to choose from. Fortunately, we can 
often warm start the optimization procedure, using the value of $\hat{\mathbf{w}}(\lambda_k)$ as an initializer for $\hat{\mathbf{w}}(\lambda_{k+1})$, 
where $\lambda_{k+1} < \lambda_k$; in other words, we start with a highly constrained model (strong regularizer), and 
then gradually relax the constraints (decrease the amount of regularization). The set of parameters 
$\hat{\mathbf{w}}_k$ that we sweep out in this way is known as the regularization path. See Figure 11.10(a) for an 
example. 

We can also use an empirical Bayes approach to choose $\lambda$. In particular, we choose the hyperparameter 
by computing $\hat{\lambda} = \operatorname*{argmax}_{\lambda} \log p(\mathcal{D} \mid \lambda)$, where $p(\mathcal{D} \mid \lambda)$ is the marginal likelihood or evidence. 
Figure 4.7b shows that this gives essentially the same result as the CV estimate. However, the 
Bayesian approach has several advantages: computing $p(\mathcal{D} \mid \lambda)$ can be done by fitting a single model, 
whereas CV has to fit the same model K times; and $p(\mathcal{D} \mid \lambda)$ is a smooth function of $\lambda$, so we can use 
gradient-based optimization instead of discrete search. 


11.4 Lasso regression 


In Section 11.3, we assumed a Gaussian prior for the regression coefficients when fitting linear 
regression models. This is often a good choice, since it encourages the parameters to be small, and 
hence prevents overfitting. However, sometimes we want the parameters to not just be small, but to 
be exactly zero, i.e., we want w to be sparse, so that we minimize the $\ell_0$-norm: 

$$\|\mathbf{w}\|_0 = \sum_{d=1}^D \mathbb{I}(|w_d| > 0) \tag{11.72}$$

This is useful because it can be used to perform feature selection. To see this, note that the 
prediction has the form $f(\mathbf{x}; \mathbf{w}) = \sum_{d=1}^D w_d x_d$, so if any $w_d = 0$, we ignore the corresponding feature 
$x_d$. (The same idea can be applied to nonlinear models, such as DNNs, by encouraging the first layer 
weights to be sparse.) 


11.4.1 MAP estimation with a Laplace prior ($\ell_1$ regularization) 


There are many ways to compute such sparse estimates (see e.g., [Bha+19]). In this section we focus 
on MAP estimation using the Laplace distribution (which we discussed in Section 11.6.1) as the 
prior: 

$$p(\mathbf{w} \mid \lambda) = \prod_{d=1}^D \mathrm{Laplace}(w_d \mid 0, 1/\lambda) \propto \prod_{d=1}^D e^{-\lambda|w_d|} \tag{11.73}$$

where $\lambda$ is the sparsity parameter, and 

$$\mathrm{Laplace}(w \mid \mu, b) \triangleq \frac{1}{2b}\exp\left(-\frac{|w - \mu|}{b}\right) \tag{11.74}$$

Here $\mu$ is a location parameter and b > 0 is a scale parameter. Figure 2.15 shows that $\mathrm{Laplace}(w \mid 0, b)$ 
puts more density on 0 than $\mathcal{N}(w \mid 0, \sigma^2)$, even when we fix the variance to be the same. 
To perform MAP estimation of a linear regression model with this prior, we just have to minimize 
the following objective: 

$$\mathrm{PNLL}(\mathbf{w}) = -\log p(\mathcal{D} \mid \mathbf{w}) - \log p(\mathbf{w} \mid \lambda) = \|\mathbf{X}\mathbf{w} - \mathbf{y}\|_2^2 + \lambda\|\mathbf{w}\|_1 \tag{11.75}$$

where $\|\mathbf{w}\|_1 \triangleq \sum_{d=1}^D |w_d|$ is the $\ell_1$ norm of w. This method is called lasso, which stands for “least 
absolute shrinkage and selection operator” [Tib96]. (We explain the reason for this name below.) 
More generally, MAP estimation with a Laplace prior is called $\ell_1$-regularization. 
Figure 11.8: Illustration of $\ell_1$ (left) vs $\ell_2$ (right) regularization of a least squares problem. Adapted from 
Figure 3.12 of [HTF01]. 


Note also that we could use other norms for the weight vector. In general, the q-norm is defined as 
follows: 

$$\|\mathbf{w}\|_q = \left(\sum_{d=1}^D |w_d|^q\right)^{1/q} \tag{11.76}$$

For q < 1, we can get even sparser solutions. In the limit where q = 0, we get the $\ell_0$-norm: 

$$\|\mathbf{w}\|_0 = \sum_{d=1}^D \mathbb{I}(|w_d| > 0) \tag{11.77}$$

However, one can show that for any q < 1, the problem becomes non-convex (see e.g., [HTW15]). 
Thus the $\ell_1$-norm is the tightest convex relaxation of the $\ell_0$-norm. 


11.4.2 Why does $\ell_1$ regularization yield sparse solutions? 


We now explain why $\ell_1$ regularization results in sparse solutions, whereas $\ell_2$ regularization does not. 
We focus on the case of linear regression, although similar arguments hold for other models. 

The lasso objective is the following non-smooth objective (see Section 8.1.4 for a discussion of 
smoothness): 

$$\min_{\mathbf{w}} \mathrm{NLL}(\mathbf{w}) + \lambda\|\mathbf{w}\|_1 \tag{11.78}$$

This is the Lagrangian for the following quadratic program (see Section 8.5.4): 

$$\min_{\mathbf{w}} \mathrm{NLL}(\mathbf{w}) \quad \text{s.t.} \quad \|\mathbf{w}\|_1 \le B \tag{11.79}$$

where B is an upper bound on the $\ell_1$-norm of the weights: a small (tight) bound B corresponds to a 
large penalty $\lambda$, and vice versa. 

Similarly, we can write the ridge regression objective $\min_{\mathbf{w}} \mathrm{NLL}(\mathbf{w}) + \lambda\|\mathbf{w}\|_2^2$ in bound constrained 
form: 

$$\min_{\mathbf{w}} \mathrm{NLL}(\mathbf{w}) \quad \text{s.t.} \quad \|\mathbf{w}\|_2^2 \le B \tag{11.80}$$
In Figure 11.8, we plot the contours of the NLL objective function, as well as the contours of the 
$\ell_2$ and $\ell_1$ constraint surfaces. From the theory of constrained optimization (Section 8.5) we know 
that the optimal solution occurs at the point where the lowest level set of the objective function 
intersects the constraint surface (assuming the constraint is active). It should be geometrically clear 
that as we relax the constraint B, we “grow” the $\ell_1$ “ball” until it meets the objective; the corners of 
the ball are more likely to intersect the ellipse than one of the sides, especially in high dimensions, 
because the corners “stick out” more. The corners correspond to sparse solutions, which lie on the 
coordinate axes. By contrast, when we grow the $\ell_2$ ball, it can intersect the objective at any point; 
there are no “corners”, so there is no preference for sparsity. 


11.4.3 Hard vs soft thresholding 


The lasso objective has the form $\mathcal{L}(\mathbf{w}) = \mathrm{NLL}(\mathbf{w}) + \lambda\|\mathbf{w}\|_1$. One can show (Exercise 11.3) that the 
gradient for the smooth NLL part is given by 

\begin{align}
\frac{\partial}{\partial w_d}\mathrm{NLL}(\mathbf{w}) &= a_d w_d - c_d \tag{11.81} \\
a_d &= \sum_{n=1}^N x_{nd}^2 \tag{11.82} \\
c_d &= \sum_{n=1}^N x_{nd}(y_n - \mathbf{w}_{-d}^\top\mathbf{x}_{n,-d}) \tag{11.83}
\end{align}

where $\mathbf{w}_{-d}$ is w without component d, and similarly $\mathbf{x}_{n,-d}$ is feature vector $\mathbf{x}_n$ without component 
d. We see that $c_d$ is proportional to the correlation between the d'th column of features, $\mathbf{x}_{:,d}$, and the 
residual error obtained by predicting using all the other features, $\mathbf{r}_{-d} = \mathbf{y} - \mathbf{X}_{:,-d}\mathbf{w}_{-d}$. Hence the 
magnitude of $c_d$ is an indication of how relevant feature d is for predicting y, relative to the other 
features and the current parameters. Setting the gradient to 0 gives the optimal update for $w_d$, 
keeping all other weights fixed: 

$$w_d = c_d/a_d = \frac{\mathbf{x}_{:,d}^\top\mathbf{r}_{-d}}{\|\mathbf{x}_{:,d}\|_2^2} \tag{11.84}$$

The corresponding new prediction for $\mathbf{r}_{-d}$ becomes $\hat{\mathbf{r}}_{-d} = w_d\mathbf{x}_{:,d}$, which is the orthogonal projection 
of the residual onto the column vector $\mathbf{x}_{:,d}$, consistent with Equation (11.15). 

Now we add in the $\ell_1$ term. Unfortunately, the $\|\mathbf{w}\|_1$ term is not differentiable whenever $w_d = 0$. 
Fortunately, we can still compute a subgradient at this point. Using Equation (8.14) we find that 

\begin{align}
\partial_{w_d}\mathcal{L}(\mathbf{w}) &= (a_d w_d - c_d) + \lambda\,\partial_{w_d}\|\mathbf{w}\|_1 \tag{11.85} \\
&= \begin{cases}
\{a_d w_d - c_d - \lambda\} & \text{if } w_d < 0 \\
[-c_d - \lambda, -c_d + \lambda] & \text{if } w_d = 0 \\
\{a_d w_d - c_d + \lambda\} & \text{if } w_d > 0
\end{cases} \tag{11.86}
\end{align}

Depending on the value of $c_d$, the solution to $\partial_{w_d}\mathcal{L}(\mathbf{w}) = 0$ can occur at 3 different values of $w_d$, 
as follows: 


1. If $c_d < -\lambda$, so the feature is strongly negatively correlated with the residual, then the subgradient 
is zero at $\hat{w}_d = \frac{c_d + \lambda}{a_d} < 0$. 


2. If $c_d \in [-\lambda, \lambda]$, so the feature is only weakly correlated with the residual, then the subgradient is 
zero at $\hat{w}_d = 0$. 


3. If $c_d > \lambda$, so the feature is strongly positively correlated with the residual, then the subgradient is 
zero at $\hat{w}_d = \frac{c_d - \lambda}{a_d} > 0$. 


In summary, we have 

$$\hat{w}_d(c_d) = \begin{cases}
(c_d + \lambda)/a_d & \text{if } c_d < -\lambda \\
0 & \text{if } c_d \in [-\lambda, \lambda] \\
(c_d - \lambda)/a_d & \text{if } c_d > \lambda
\end{cases} \tag{11.87}$$

We can write this as follows: 

$$\hat{w}_d = \mathrm{SoftThreshold}\left(\frac{c_d}{a_d}, \lambda/a_d\right) \tag{11.88}$$

where 

$$\mathrm{SoftThreshold}(x, \delta) \triangleq \mathrm{sign}(x)\,(|x| - \delta)_+ \tag{11.89}$$

and $x_+ = \max(x, 0)$ is the positive part of x. This is called soft thresholding (see also Section 8.6.2). 
This is illustrated in Figure 11.9(a), where we plot $\hat{w}_d$ vs $c_d$. The dotted black line is the line 
$w_d = c_d/a_d$ corresponding to the least squares fit. The solid red line, which represents the regularized 
estimate $\hat{w}_d$, shifts the dotted line down (or up) by $\lambda$, except when $-\lambda \le c_d \le \lambda$, in which case it 
sets $w_d = 0$. 


Figure 11.9: Left: soft thresholding. Right: hard thresholding. In both cases, the horizontal axis is the residual 
error incurred by making predictions using all the coefficients except for $w_k$, and the vertical axis is the 
estimated coefficient $\hat{w}_k$ that minimizes this penalized residual. The flat region in the middle is the interval 
$[-\lambda, +\lambda]$. 

Figure 11.10: (a) Profiles of ridge coefficients for the prostate cancer example vs bound B on the $\ell_2$ norm of w, 
so small B (large $\lambda$) is on the left. The vertical line is the value chosen by 5-fold CV using the 1 standard 
error rule. Adapted from Figure 3.8 of [HTF09]. Generated by ridgePathProstate.ipynb. (b) Same as (a) 
but using the $\ell_1$ norm of w. The x-axis shows the critical values of $\lambda = 1/B$, where the regularization path is 
discontinuous. Adapted from Figure 3.10 of [HTF09]. Generated by lassoPathProstate.ipynb. 


By contrast, in Figure 11.9(b), we illustrate hard thresholding. This sets values of $w_d$ to 0 if 
$-\lambda \le c_d \le \lambda$, but it does not shrink the values of $w_d$ outside of this interval. The slope of the soft 
thresholding line does not coincide with the diagonal, which means that even large coefficients are 
shrunk towards zero. This is why lasso stands for “least absolute shrinkage and selection operator”. 
Consequently, lasso is a biased estimator (see Section 4.7.6.1). 
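
To make the thresholding operators concrete, the following sketch implements Equation (11.89), a hard-thresholding counterpart, and a simple cyclic coordinate descent loop that applies the update of Equation (11.88) to each weight in turn. It assumes the objective is scaled as $\frac{1}{2}\|\mathbf{y} - \mathbf{X}\mathbf{w}\|_2^2 + \lambda\|\mathbf{w}\|_1$, which is the scaling under which the NLL gradient is $a_d w_d - c_d$; the function names, the fixed number of sweeps, and the synthetic data are all assumptions made for this illustration.

```python
import numpy as np

def soft_threshold(x, delta):
    """sign(x) * max(|x| - delta, 0), as in Equation (11.89)."""
    return np.sign(x) * np.maximum(np.abs(x) - delta, 0.0)

def hard_threshold(x, delta):
    """Zero out entries with |x| <= delta and leave the rest unshrunk."""
    return np.where(np.abs(x) > delta, x, 0.0)

def lasso_coordinate_descent(X, y, lam, n_sweeps=200):
    """Cycle through coordinates, applying the update of Equation (11.88) to each."""
    N, D = X.shape
    w = np.zeros(D)
    a = np.sum(X**2, axis=0)                        # a_d, Equation (11.82)
    for _ in range(n_sweeps):
        for d in range(D):
            r_minus_d = y - X @ w + X[:, d] * w[d]  # residual that ignores feature d
            c_d = X[:, d] @ r_minus_d               # c_d, Equation (11.83)
            w[d] = soft_threshold(c_d / a[d], lam / a[d])
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 6))
w_true = np.array([3.0, 0.0, 0.0, -2.0, 0.0, 0.0])
y = X @ w_true + 0.1 * rng.normal(size=100)

lam = 20.0
w_lasso = lasso_coordinate_descent(X, y, lam)
```

At convergence, each non-zero weight should satisfy $\mathbf{x}_{:,d}^\top(\mathbf{y} - \mathbf{X}\mathbf{w}) = \lambda\,\mathrm{sign}(w_d)$, and each zero weight should have a correlation with the residual of magnitude at most $\lambda$, matching the three cases of Equation (11.87).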

A simple solution to the biased estimate problem, known as debiasing, is to use a two-stage 
estimation process: we first estimate the support of the weight vector (i.e., identify which elements 
are non-zero) using lasso; we then re-estimate the chosen coefficients using least squares. For an 
example of this in action, see Figure 11.13. 


11.4.4 Regularization path 


If $\lambda = 0$, we get the OLS solution, which will be dense. As we increase $\lambda$, the solution vector $\hat{\mathbf{w}}(\lambda)$ 
will tend to get sparser. If $\lambda$ is bigger than some critical value, we get $\hat{\mathbf{w}} = \mathbf{0}$. This critical value is 
obtained when the gradient of the NLL cancels out with the gradient of the penalty: 

$$\lambda_{\max} = \max_d |\nabla_{w_d}\mathrm{NLL}(\mathbf{0})| = \max_d |c_d(\mathbf{w} = \mathbf{0})| = \max_d |\mathbf{y}^\top\mathbf{x}_{:,d}| = \|\mathbf{X}^\top\mathbf{y}\|_\infty \tag{11.90}$$

Alternatively, we can work with the bound B on the $\ell_1$ norm. When B = 0, we get w = 0. As we 
increase B, the solution becomes denser. The largest value of B for which any component is zero is 
given by $B_{\max} = \|\hat{\mathbf{w}}_{\text{mle}}\|_1$. 

As we increase $\lambda$, the solution vector $\hat{\mathbf{w}}$ gets sparser, although not necessarily monotonically. We can 
plot the values $\hat{w}_d$ vs $\lambda$ (or vs the bound B) for each feature d; this is known as the regularization 
path. This is illustrated in Figure 11.10(b), where we apply lasso to the prostate cancer regression 
dataset from [HTF09]. (We treat features gleason and svi as numeric, not categorical.) On the left, 
when B = 0, all the coefficients are zero. As we increase B, the coefficients gradually “turn on”.² The 
analogous result for ridge regression is shown in Figure 11.10(a). For ridge, we see all coefficients are 
non-zero (assuming $\lambda > 0$), so the solution is not sparse. 


0        0        0        0        0        0        0        0 
0.4279   0        0        0        0        0        0        0 
0.5015   0.0735   0        0        0        0        0        0 
0.5610   0.1878   0        0        0.0930   0        0        0 
0.5622   0.1890   0        0.0036   0.0963   0        0        0 
0.5797   0.2456   0        0.1435   0.2003   0        0        0.0901 
0.5864   0.2572   -0.0321  0.1639   0.2082   0        0        0.1066 
0.6994   0.2910   -0.1337  0.2062   0.3003   -0.2565  0        0.2452 
0.7164   0.2926   -0.1425  0.2120   0.3096   -0.2890  -0.0209  0.2773 


Table 11.1: Values of the coefficients for linear regression model fit to prostate cancer dataset as we vary the 
strength of the $\ell_1$ regularizer. These numbers are plotted in Figure 11.10(b). 

Remarkably, it can be shown that the lasso solution path is a piecewise linear function of $\lambda$ [Efr+04; 
GL15]. That is, there are a set of critical values of $\lambda$ where the active set of non-zero coefficients 
changes. For values of $\lambda$ between these critical values, each non-zero coefficient increases or decreases 
in a linear fashion. This is illustrated in Figure 11.10(b). Furthermore, one can solve for these critical 
values analytically [Efr+04]. In Table 11.1, we display the actual coefficient values at each of these 
critical steps along the regularization path (the last line is the least squares solution). 

By changing $\lambda$ from $\lambda_{\max}$ to 0, we can go from a solution in which all the weights are zero to a 
solution in which all weights are non-zero. Unfortunately, not all subset sizes are achievable using 
lasso. In particular, one can show that, if D > N, the optimal solution can have at most N variables 
in it, before reaching the complete set corresponding to the OLS solution of minimal $\ell_1$ norm. In 
Section 11.4.8, we will see that by using an $\ell_2$ regularizer as well as an $\ell_1$ regularizer (a method 
known as the elastic net), we can achieve sparse solutions which contain more variables than training 
cases. This lets us explore model sizes between N and D. 


11.4.5 Comparison of least squares, lasso, ridge and subset selection 


In this section, we compare least squares, lasso, ridge and subset selection. For simplicity, we assume 
all the features of X are orthonormal, so X'X =I. In this case, the NLL is given by 


NLL(w) = ||y — Xw||? = y'y + w'X'Xw — 2w' X'y (11.91) 
= const + 5 w —2 5 5 WdEndYn (11.92) 
d d t 


so we see this factorizes into a sum of terms, one per dimension. Hence we can write down the MAP 
and ML estimates analytically for each wa separately, as given below. 


e MLE From Equation (11.85), the OLS solution is given by 
amie = ca/Qa — xy (11.93) 


where 2. is the d’th column of X. 


2. It is common to plot the solution versus the shrinkage factor, defined as $s(B) = B / B_{\max}$, rather than against $B$.
This merely affects the scale of the horizontal axis, not the shape of the curves.

Term         OLS      Best Subset   Ridge    Lasso
intercept    2.465    2.477         2.467    2.465
lcavol       0.676    0.736         0.522    0.548
lweight      0.262    0.315         0.255    0.224
age         -0.141    0.000        -0.089    0.000
lbph         0.209    0.000         0.186    0.129
svi          0.304    0.000         0.259    0.186
lcp         -0.287    0.000        -0.095    0.000
gleason     -0.021    0.000         0.025    0.000
pgg45        0.266    0.000         0.169    0.083
Test error   0.521    0.492         0.487    0.457
Std error    0.176    0.141         0.157    0.146


Figure 11.11: Results of different methods on the prostate cancer data, which has 8 features and 67 training
cases. Methods are: OLS = ordinary least squares, Subset = best subset regression, Ridge, Lasso. Rows
represent the coefficients; we see that subset regression and lasso give sparse solutions. Bottom row is
the mean squared error on the test set (30 cases). Adapted from Table 3.8 of [HTF09]. Generated by
prostate_comparison.ipynb.


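• Ridge One can show that the ridge estimate is given by

$$\hat{w}_d^{\text{ridge}} = \frac{\hat{w}_d^{\text{mle}}}{1 + \lambda} \tag{11.94}$$

• Lasso From Equation (11.88), and using the fact that $\hat{w}_d^{\text{mle}} = c_d / a_d$, we have

$$\hat{w}_d^{\text{lasso}} = \operatorname{sign}(\hat{w}_d^{\text{mle}}) \left( |\hat{w}_d^{\text{mle}}| - \lambda \right)_+ \tag{11.95}$$

This corresponds to soft thresholding, shown in Figure 11.9(a).

• Subset selection If we pick the best $K$ features using subset selection, the parameter estimate
is as follows

$$\hat{w}_d^{\text{ss}} = \begin{cases} \hat{w}_d^{\text{mle}} & \text{if } \operatorname{rank}(|\hat{w}_d^{\text{mle}}|) \le K \\ 0 & \text{otherwise} \end{cases} \tag{11.96}$$

where rank refers to the location in the sorted list of weight magnitudes. This corresponds to
hard thresholding, shown in Figure 11.9(b).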
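
The four closed-form estimates above are easy to compare numerically. The following is a minimal sketch, assuming an orthonormal design built from a QR decomposition and following Equations (11.93) to (11.96) as written; the function names `soft_threshold` and `orthonormal_estimates`, and the synthetic data, are illustrative choices and not part of the book's notebooks.

```python
import numpy as np

def soft_threshold(a, delta):
    """Soft thresholding: sign(a) * max(|a| - delta, 0), applied elementwise."""
    return np.sign(a) * np.maximum(np.abs(a) - delta, 0.0)

def orthonormal_estimates(X, y, lam, K):
    """Closed-form estimates, valid only when X has orthonormal columns."""
    w_mle = X.T @ y                          # Equation (11.93)
    w_ridge = w_mle / (1.0 + lam)            # Equation (11.94): uniform shrinkage
    w_lasso = soft_threshold(w_mle, lam)     # Equation (11.95): soft thresholding
    # Equation (11.96): keep the K largest-magnitude coefficients, zero the rest
    w_subset = np.zeros_like(w_mle)
    keep = np.argsort(-np.abs(w_mle))[:K]
    w_subset[keep] = w_mle[keep]
    return w_mle, w_ridge, w_lasso, w_subset

# Synthetic data (an assumption for illustration): Q has orthonormal columns
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(50, 5)))
w_true = np.array([3.0, -2.0, 0.5, 0.0, 0.0])
y = Q @ w_true + 0.1 * rng.normal(size=50)
w_mle, w_ridge, w_lasso, w_subset = orthonormal_estimates(Q, y, lam=1.0, K=2)
```

Ridge rescales every coefficient by the same factor, lasso shifts each one towards zero and clips the small ones, and subset selection leaves the survivors untouched.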


We now experimentally compare the prediction performance of these methods on the prostate cancer
regression dataset from [HTF09]. (We treat features gleason and svi as numeric, not categorical.)
Figure 11.11 shows the estimated coefficients at the value of $\lambda$ (or $K$) chosen by cross-validation;
we see that the subset method is the sparsest, then lasso. In terms of predictive performance, all
methods are very similar, as can be seen from Figure 11.12.


Figure 11.12: Boxplot displaying (absolute value of) prediction errors on the prostate cancer test set for
different regression methods (LS, Best Subset, Ridge, Lasso). Generated by prostate_comparison.ipynb.


11.4.6 Variable selection consistency 


It is common to use $\ell_1$ regularization to estimate the set of relevant variables, a process known as
variable selection. A method that can recover the true set of relevant variables (i.e., the support
of $\mathbf{w}^*$) in the $N \to \infty$ limit is called model selection consistent. (This is a theoretical notion
that assumes the data comes from the model.)

Let us give an example. We first generate a sparse signal $\mathbf{w}^*$ of size $D = 4096$, consisting of 160
randomly placed $\pm 1$ spikes. Next we generate a random design matrix $\mathbf{X}$ of size $N \times D$, where
$N = 1024$. Finally we generate a noisy observation $\mathbf{y} = \mathbf{X}\mathbf{w}^* + \boldsymbol{\epsilon}$, where $\epsilon_n \sim \mathcal{N}(0, 0.01^2)$. We then
estimate $\mathbf{w}$ from $\mathbf{y}$ and $\mathbf{X}$. The original $\mathbf{w}^*$ is shown in the first row of Figure 11.13. The second
row is the $\ell_1$ estimate $\hat{\mathbf{w}}_{L1}$ using $\lambda = 0.1 \lambda_{\max}$. We see that this has “spikes” in the right places, so
it has correctly identified the relevant variables. However, although $\hat{\mathbf{w}}_{L1}$ has correctly
identified the non-zero components, they are too small, due to shrinkage. In the third row, we
show the results of using the debiasing technique discussed in Section 11.4.3. This shows that we
can recover the original weight vector. By contrast, the final row shows the OLS estimate, which is
dense. Furthermore, it is visually clear that there is no single threshold value we can apply to $\hat{\mathbf{w}}_{\text{mle}}$
to recover the correct sparse weight vector.

To use lasso to perform variable selection, we have to pick $\lambda$. It is common to use cross validation
to pick the optimal value on the regularization path. However, it is important to note that cross
validation is picking a value of $\lambda$ that results in good predictive accuracy. This is not usually the same
value as the one that is likely to recover the “true” model. To see why, recall that $\ell_1$ regularization
performs selection and shrinkage, that is, the chosen coefficients are brought closer to 0. In order to
prevent relevant coefficients from being shrunk in this way, cross validation will tend to pick a value
of $\lambda$ that is not too large. Of course, this will result in a less sparse model which contains irrelevant
variables (false positives). Indeed, it was proved in [MB06] that the prediction-optimal value of $\lambda$
does not result in model selection consistency. However, various extensions to the basic method have
been devised that are model selection consistent (see e.g., [BG11; HTW15]).

[Figure 11.13 panels, top to bottom: “Original (D = 4096, number of nonzeros = 160)”; “L1 reconstruction (K0 = 1024, lambda = 0.0516, MSE = 0.0027)”; “Debiased (MSE = 3.26e-005)”; “Minimum norm solution (MSE = 0.0292)”.]

Figure 11.13: Example of recovering a sparse signal using lasso. See text for details. Adapted from Figure 1
of [FNW07]. Generated by sparse_sensing_demo.ipynb.


11.4.7 Group lasso 


In standard $\ell_1$ regularization, we assume that there is a 1:1 correspondence between parameters
and variables, so that if $\hat{w}_d = 0$, we interpret this to mean that variable $d$ is excluded. But in more
complex models, there may be many parameters associated with a given variable. In particular,
each variable $d$ may have a vector of weights $\mathbf{w}_d$ associated with it, so the overall weight vector has
block structure, $\mathbf{w} = [\mathbf{w}_1, \mathbf{w}_2, \ldots, \mathbf{w}_D]$. If we want to exclude variable $d$, we have to force the whole
subvector $\mathbf{w}_d$ to go to zero. This is called group sparsity.


11.4.7.1 Applications 


Here are some examples where group sparsity is useful: 


• Linear regression with categorical inputs: If the $d$'th variable is categorical with $K$ possible levels,
then it will be represented as a one-hot vector of length $K$ (Section 1.5.3.1), so to exclude variable
$d$, we have to set the whole vector of incoming weights to 0.


• Multinomial logistic regression: The $d$'th variable will be associated with $C$ different weights, one
per class (Section 10.3), so to exclude variable $d$, we have to set the whole vector of outgoing
weights to 0.


• Neural networks: the $k$'th neuron will have multiple inputs, so if we want to “turn the neuron
off”, we have to set all the incoming weights to zero. This allows us to use group sparsity to learn
neural network structure (for details, see e.g., [GEH19]).


• Multi-task learning: each input feature is associated with $C$ different weights, one per output task.
If we want to use a feature for all of the tasks or none of the tasks, we should select weights at
the group level [OTJ07].


11.4.7.2 Penalizing the two-norm 


To encourage group sparsity, we partition the parameter vector into $G$ groups, $\mathbf{w} = [\mathbf{w}_1, \ldots, \mathbf{w}_G]$.
Then we minimize the following objective

$$\mathrm{PNLL}(\mathbf{w}) = \mathrm{NLL}(\mathbf{w}) + \lambda \sum_{g=1}^G \|\mathbf{w}_g\|_2 \tag{11.97}$$

where $\|\mathbf{w}_g\|_2 = \sqrt{\sum_{d \in g} w_d^2}$ is the 2-norm of the group weight vector. If the NLL is least squares,
this method is called group lasso [YL06; Kyu+10].

Note that if we had used the sum of the squared 2-norms in Equation (11.97), then the model
would become equivalent to ridge regression, since

$$\sum_{g=1}^G \|\mathbf{w}_g\|_2^2 = \sum_g \sum_{d \in g} w_d^2 = \|\mathbf{w}\|_2^2 \tag{11.98}$$


By using the square root, we are penalizing the radius of a ball containing the group’s weight vector: 
the only way for the radius to be small is if all elements are small. 

Another way to see why the square root version enforces sparsity at the group level is to consider
the gradient of the objective. Suppose there is only one group of two variables, so the penalty has
the form $\sqrt{w_1^2 + w_2^2}$. The derivative wrt $w_1$ is

$$\frac{\partial}{\partial w_1} (w_1^2 + w_2^2)^{\frac{1}{2}} = \frac{w_1}{\sqrt{w_1^2 + w_2^2}} \tag{11.99}$$

If $w_2$ is close to zero, then the derivative approaches 1, and $w_1$ is driven to zero as well, with force
proportional to $\lambda$. If, however, $w_2$ is large, the derivative approaches 0, and $w_1$ is free to stay large
as well. So all the coefficients in the group will have similar size.
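
The behaviour of the derivative in Equation (11.99) can be checked numerically. Below is a small sketch; the helper names `group_penalty` and `group_penalty_grad` are hypothetical, and the gradient is only defined where every group has a non-zero norm.

```python
import numpy as np

def group_penalty(w, groups):
    """Sum of (unsquared) 2-norms of the group subvectors, as in Equation (11.97)."""
    return sum(np.linalg.norm(w[idx]) for idx in groups)

def group_penalty_grad(w, groups):
    """Gradient of the group penalty; assumes every group norm is non-zero."""
    g = np.zeros_like(w, dtype=float)
    for idx in groups:
        g[idx] = w[idx] / np.linalg.norm(w[idx])   # Equation (11.99), per group
    return g

groups = [np.array([0, 1])]   # a single group of two variables

# Same w_1 = 0.1 in both cases; only the size of its partner w_2 changes.
grad_small_partner = group_penalty_grad(np.array([0.1, 1e-6]), groups)[0]
grad_large_partner = group_penalty_grad(np.array([0.1, 100.0]), groups)[0]
```

With a tiny partner the derivative wrt $w_1$ is close to 1, so $w_1$ feels the full penalty; with a large partner it is close to 0, so $w_1$ is barely penalized.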


11.4.7.3 Penalizing the infinity norm


A variant of this technique replaces the 2-norm with the infinity-norm [TVW05; ZRY05]:

$$\|\mathbf{w}_g\|_\infty = \max_{d \in g} |w_d| \tag{11.100}$$

It is clear that this will also result in group sparsity, since if the largest element in the group is forced
to zero, all the smaller ones will be as well.


11.4.7.4 Example 


An illustration of these techniques is shown in Figure 11.14 and Figure 11.15. We have a true signal
$\mathbf{w}$ of size $D = 2^{12} = 4096$, divided into 64 groups each of size 64. We randomly choose 8 groups
of $\mathbf{w}$ and assign them non-zero values. In Figure 11.14 the values are drawn from a $\mathcal{N}(0, 1)$; in
Figure 11.15, the values are all set to 1. We then sample a random design matrix $\mathbf{X}$ of size $N \times D$,
where $N = 2^{10} = 1024$. Finally, we generate $\mathbf{y} = \mathbf{X}\mathbf{w} + \boldsymbol{\epsilon}$, where $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, 10^{-4} \mathbf{I}_N)$. Given this data,
we estimate the support of $\mathbf{w}$ using $\ell_1$ or group $\ell_1$, and then estimate the non-zero values using least
squares (debiased estimate).

We see from the figures that group lasso does a much better job than vanilla lasso, since it respects
the known group structure. We also see that the $\ell_\infty$ norm has a tendency to make all the elements
within a block have similar magnitude. This is appropriate in the second example, but not the
first. (The value of $\lambda$ was the same in all examples, and was chosen by hand.)

[Figure 11.14 panel titles recovered from the plot: “Original (D = 4096, number groups = 64, active groups = 8)”; “Standard L1 (debiased 1, tau = 0.427, MSE = 0.08415)”.]

Figure 11.14: Illustration of group lasso where the original signal is piecewise Gaussian. (a) Original signal.
(b) Vanilla lasso estimate. (c) Group lasso estimate using an $\ell_2$ norm on the blocks. (d) Group lasso estimate
using an $\ell_\infty$ norm on the blocks. Adapted from Figures 3-4 of [WNF09]. Generated by groupLassoDemo.ipynb.

Figure 11.15: Same as Figure 11.14, except the original signal is piecewise constant. Generated by
groupLassoDemo.ipynb.


11.4.8 Elastic net (ridge and lasso combined) 


In group lasso, we need to specify the group structure ahead of time. For some problems, we don’t
know the group structure, and yet we would still like highly correlated coefficients to be treated as an
implicit group. One way to achieve this effect, proposed in [ZH05], is to use the elastic net, which is
a hybrid between lasso and ridge regression. This corresponds to minimizing the following objective:

$$\mathcal{L}(\mathbf{w}, \lambda_1, \lambda_2) = \|\mathbf{y} - \mathbf{X}\mathbf{w}\|^2 + \lambda_2 \|\mathbf{w}\|_2^2 + \lambda_1 \|\mathbf{w}\|_1 \tag{11.101}$$

This penalty function is strictly convex (assuming $\lambda_2 > 0$) so there is a unique global minimum, even
if $\mathbf{X}$ is not full rank. It can be shown [ZH05] that any strictly convex penalty on $\mathbf{w}$ will exhibit a
grouping effect, which means that the regression coefficients of highly correlated variables tend to
be equal. In particular, if two features are identically equal, so $\mathbf{X}_{:j} = \mathbf{X}_{:k}$, one can show that their
estimates are also equal, $\hat{w}_j = \hat{w}_k$. By contrast, with lasso, we may have that $\hat{w}_j = 0$ and $\hat{w}_k \neq 0$ or
vice versa, resulting in less stable estimates.

3. It is apparently called the “elastic net” because it is “like a stretchable fishing net that retains all the big fish” [ZH05].

In addition to its soft grouping behavior, elastic net has other advantages. In particular, if $D > N$,
the maximum number of non-zero elements that lasso can select (excluding the MLE, which has $D$
non-zero elements) is $N$. By contrast, elastic net can select more than $N$ non-zero variables on its
path to the dense estimate, thus exploring more possible subsets of variables.


11.4.9 Optimization algorithms 


A large variety of algorithms have been proposed to solve the lasso problem, and other $\ell_1$-regularized
convex objectives. In this section, we briefly mention some of the most popular methods.


11.4.9.1 Coordinate descent 


Sometimes it is hard to optimize all the variables simultaneously, but it is easy to optimize them one
by one. In particular, we can solve for the $j$'th coefficient with all the others held fixed as follows:

$$w_j^* = \operatorname*{argmin}_{\eta} \mathcal{L}(\mathbf{w} + \eta \mathbf{e}_j) \tag{11.102}$$

where $\mathbf{e}_j$ is the $j$'th unit vector. This is called coordinate descent. We can either cycle through
the coordinates in a deterministic fashion, or we can sample them at random, or we can choose to
update the coordinate for which the gradient is steepest.

This method is particularly appealing if each one-dimensional optimization problem can be solved
analytically, as is the case for lasso (see Equation (11.87)). This is known as the shooting algorithm
[Fu98; WL08]. (The term “shooting” is a reference to the cowboy theme inspired by the term “lasso”.)
See Algorithm 11.1 for details.

This coordinate descent method has been generalized to the GLM case in [FHT10], and is the
basis of the popular glmnet software library.


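Algorithm 11.1: Coordinate descent for lasso (aka shooting algorithm)
1 Initialize $\mathbf{w} = (\mathbf{X}^\top \mathbf{X} + \lambda \mathbf{I})^{-1} \mathbf{X}^\top \mathbf{y}$
2 repeat
3   for $d = 1, \ldots, D$ do
4     $a_d = \sum_{n=1}^N x_{nd}^2$
5     $c_d = \sum_{n=1}^N x_{nd} (y_n - \mathbf{w}^\top \mathbf{x}_n + w_d x_{nd})$
6     $w_d = \text{SoftThreshold}(c_d / a_d, \lambda / a_d)$
7 until converged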
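
Algorithm 11.1 translates almost line by line into NumPy. The following is a minimal sketch, not the glmnet implementation. It assumes the objective is scaled as $\frac{1}{2}\|\mathbf{y} - \mathbf{X}\mathbf{w}\|^2 + \lambda \|\mathbf{w}\|_1$, which is the scaling under which the threshold $\lambda / a_d$ arises; the stopping rule (`tol`, `max_iters`) and the synthetic data are illustrative assumptions.

```python
import numpy as np

def soft_threshold(a, delta):
    return np.sign(a) * np.maximum(np.abs(a) - delta, 0.0)

def lasso_coordinate_descent(X, y, lam, max_iters=1000, tol=1e-8):
    """Shooting algorithm for 0.5 * ||y - Xw||^2 + lam * ||w||_1 (assumed scaling)."""
    N, D = X.shape
    # Line 1: initialize with the ridge estimate
    w = np.linalg.solve(X.T @ X + lam * np.eye(D), X.T @ y)
    a = np.sum(X ** 2, axis=0)                 # a_d does not change between sweeps
    for _ in range(max_iters):
        w_old = w.copy()
        for d in range(D):
            # c_d: correlation of feature d with the residual that excludes feature d
            residual_without_d = y - X @ w + w[d] * X[:, d]
            c_d = X[:, d] @ residual_without_d
            w[d] = soft_threshold(c_d / a[d], lam / a[d])
        if np.max(np.abs(w - w_old)) < tol:    # "until converged"
            break
    return w

# Illustrative synthetic problem with a sparse ground truth
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
w_true = np.array([2.0, 0.0, -1.5, 0.0, 0.0])
y = X @ w_true + 0.1 * rng.normal(size=100)
lam = 5.0
w_hat = lasso_coordinate_descent(X, y, lam)
```

A useful sanity check on the result is the subgradient optimality condition: $\mathbf{x}_{:d}^\top (\mathbf{y} - \mathbf{X}\hat{\mathbf{w}})$ equals $\lambda \, \mathrm{sign}(\hat{w}_d)$ for active coordinates and has magnitude at most $\lambda$ for the others.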


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


11.4.9.2 Projected gradient descent


In this section, we convert the non-differentiable $\ell_1$ penalty into a smooth regularizer. To do
this, we first use the split variable trick to define $\mathbf{w} = \mathbf{w}^+ - \mathbf{w}^-$, where $\mathbf{w}^+ = \max\{\mathbf{w}, 0\}$ and
$\mathbf{w}^- = -\min\{\mathbf{w}, 0\}$. Now we can replace $\|\mathbf{w}\|_1$ with $\sum_d (w_d^+ + w_d^-)$. We also have to replace $\mathrm{NLL}(\mathbf{w})$
with $\mathrm{NLL}(\mathbf{w}^+ - \mathbf{w}^-)$. Thus we get the following smooth, but constrained, optimization problem:

$$\min_{\mathbf{w}^+ \ge 0, \, \mathbf{w}^- \ge 0} \mathrm{NLL}(\mathbf{w}^+ - \mathbf{w}^-) + \lambda \sum_{d=1}^D (w_d^+ + w_d^-) \tag{11.103}$$

In the case of a Gaussian likelihood, the NLL becomes a least squares loss, and the objective
becomes a quadratic program (Section 8.5.4). One way to solve such problems is to use projected
gradient descent (Section 8.6.1). Specifically, we can enforce the constraint by projecting onto the
positive orthant, which we can do using $w_d := \max(w_d, 0)$; this operation is denoted by $P_+$. Thus
the projected gradient update takes the following form:

$$\begin{pmatrix} \mathbf{w}^+_{t+1} \\ \mathbf{w}^-_{t+1} \end{pmatrix} = P_+\left( \begin{pmatrix} \mathbf{w}^+_t - \eta_t \nabla \mathrm{NLL}(\mathbf{w}^+_t - \mathbf{w}^-_t) - \eta_t \lambda \mathbf{e} \\ \mathbf{w}^-_t + \eta_t \nabla \mathrm{NLL}(\mathbf{w}^+_t - \mathbf{w}^-_t) - \eta_t \lambda \mathbf{e} \end{pmatrix} \right) \tag{11.104}$$

where $\mathbf{e}$ is the vector of all ones.
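
The update in Equation (11.104) can be sketched as follows. This is an illustration under stated assumptions rather than a tuned solver: we take $\mathrm{NLL}(\mathbf{w}) = \frac{1}{2}\|\mathbf{y} - \mathbf{X}\mathbf{w}\|^2$ and use a constant step size $\eta = 1 / (2 \|\mathbf{X}\|_2^2)$, where the factor of 2 accounts for the doubled curvature of the stacked variable $(\mathbf{w}^+, \mathbf{w}^-)$. The function name and data are hypothetical.

```python
import numpy as np

def lasso_projected_gradient(X, y, lam, num_iters=5000):
    """Projected gradient descent on the split-variable problem (11.103)."""
    N, D = X.shape
    # Assumed constant step size: 1 / (2 * largest eigenvalue of X^T X)
    eta = 1.0 / (2.0 * np.linalg.norm(X, 2) ** 2)
    w_pos = np.zeros(D)
    w_neg = np.zeros(D)
    for _ in range(num_iters):
        # Gradient of 0.5 * ||y - Xw||^2 evaluated at w = w_pos - w_neg
        grad = X.T @ (X @ (w_pos - w_neg) - y)
        # Gradient step on each block, then projection onto the positive orthant
        new_pos = np.maximum(w_pos - eta * grad - eta * lam, 0.0)
        new_neg = np.maximum(w_neg + eta * grad - eta * lam, 0.0)
        w_pos, w_neg = new_pos, new_neg
    return w_pos - w_neg

# Illustrative synthetic problem
rng = np.random.default_rng(1)
X = rng.normal(size=(100, 5))
y = X @ np.array([2.0, 0.0, -1.5, 0.0, 0.0]) + 0.1 * rng.normal(size=100)
lam = 5.0
w_hat = lasso_projected_gradient(X, y, lam)
```

Note that both blocks are updated from the same gradient, with opposite signs, which mirrors the two rows of Equation (11.104).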


11.4.9.3 Proximal gradient descent 


In Section 8.6, we introduced proximal gradient descent, which can be used to optimize smooth
functions with non-smooth penalties, such as $\ell_1$. In Section 8.6.2, we showed that the proximal
operator for the $\ell_1$ penalty corresponds to soft thresholding. Thus the proximal gradient descent
update can be written as

$$\mathbf{w}_{t+1} = \text{SoftThreshold}(\mathbf{w}_t - \eta_t \nabla \mathrm{NLL}(\mathbf{w}_t), \eta_t \lambda) \tag{11.105}$$

where the soft thresholding operator (Equation (8.134)) is applied elementwise. This is called the
iterative soft thresholding algorithm or ISTA [DDDM04; Don95]. If we combine this with
Nesterov acceleration, we get the method known as “fast ISTA” or FISTA [BT09], which is widely
used to fit sparse linear models.
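
Equation (11.105) is short enough to implement directly. The sketch below is plain ISTA (no Nesterov acceleration), again assuming $\mathrm{NLL}(\mathbf{w}) = \frac{1}{2}\|\mathbf{y} - \mathbf{X}\mathbf{w}\|^2$ and a constant step size $\eta = 1 / \|\mathbf{X}\|_2^2$; the function name `ista` and the test data are illustrative.

```python
import numpy as np

def soft_threshold(a, delta):
    return np.sign(a) * np.maximum(np.abs(a) - delta, 0.0)

def ista(X, y, lam, num_iters=2000):
    """Iterative soft thresholding for 0.5 * ||y - Xw||^2 + lam * ||w||_1."""
    eta = 1.0 / np.linalg.norm(X, 2) ** 2      # assumed constant step size 1/L
    w = np.zeros(X.shape[1])
    for _ in range(num_iters):
        grad = X.T @ (X @ w - y)               # gradient of the smooth part
        w = soft_threshold(w - eta * grad, eta * lam)   # proximal step, Equation (11.105)
    return w

# With an orthonormal design the fixed point is the soft-thresholded OLS estimate
rng = np.random.default_rng(2)
Q, _ = np.linalg.qr(rng.normal(size=(50, 4)))
y = Q @ np.array([3.0, -2.0, 0.3, 0.0]) + 0.1 * rng.normal(size=50)
w_ista = ista(Q, y, lam=1.0)
w_closed_form = soft_threshold(Q.T @ y, 1.0)
```

In the orthonormal case the step size is 1 and ISTA reaches the soft-thresholded least squares solution after a single iteration, which ties this update back to the closed form for lasso in Section 11.4.5.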


11.4.9.4 LARS 


In this section, we discuss methods that can generate a set of solutions for different values of $\lambda$,
starting with the empty set, i.e., they compute the full regularization path (Section 11.4.4). These
algorithms exploit the fact that one can quickly compute $\hat{\mathbf{w}}(\lambda_k)$ from $\hat{\mathbf{w}}(\lambda_{k-1})$ if $\lambda_k \approx \lambda_{k-1}$; this is
known as warm starting. In fact, even if we only want the solution for a single value of $\lambda$, call it $\lambda_*$,
it can sometimes be computationally more efficient to compute a set of solutions, from $\lambda_{\max}$ down
to $\lambda_*$, using warm-starting; this is called a continuation method or homotopy method. This is
often much faster than directly “cold-starting” at $\lambda_*$; this is particularly true if $\lambda_*$ is small.

The LARS algorithm [Efr+04], which stands for “least angle regression and shrinkage”, is an
example of a homotopy method for the lasso problem. This can compute $\hat{\mathbf{w}}(\lambda)$ for all possible values
of $\lambda$ in an efficient manner. (A similar algorithm was independently invented in [OPT00b; OPT00a]).

LARS works as follows. It starts with a large value of $\lambda$, such that only the variable that is most
correlated with the response vector $\mathbf{y}$ is chosen. Then $\lambda$ is decreased until a second variable is found
which has the same correlation (in terms of magnitude) with the current residual as the first variable,
where the residual at step $k$ on the path is defined as $\mathbf{r}_k = \mathbf{y} - \mathbf{X}_{:, F_k} \mathbf{w}_k$, where $F_k$ is the current
active set (cf., Equation (11.83)). Remarkably, one can solve for this new value of $\lambda$ analytically,
by using a geometric argument (hence the term “least angle”). This allows the algorithm to quickly
“jump” to the next point on the regularization path where the active set changes. This repeats until
all the variables are added.

It is necessary to allow variables to be removed from the current active set, even as we decrease $\lambda$,
if we want the sequence of solutions to correspond to the regularization path of lasso. If we disallow
variable removal, we get a slightly different algorithm called least angle regression or LAR. LAR
is very similar to greedy forward selection, and a method known as least squares boosting
(see e.g., [HTW15]).


11.5 Regression splines * 


We have seen how we can use polynomial basis functions to create nonlinear mappings from input to 
output, even though the model remains linear in the parameters. One problem with polynomials is 
that they are a global approximation to the function. We can achieve more flexibility by using a 
series of local approximations. To do this, we just need to define a set of basis functions that have 
local support. The notion of “locality” is hard to define in high-dimensional input spaces, so in this 
section, we restrict ourselves to 1d inputs. We can then approximate the function using

$$f(x; \boldsymbol{\theta}) = \sum_{i=1}^m w_i B_i(x) \tag{11.106}$$

where $B_i$ is the $i$'th basis function.

A common way to define such basis functions is to use B-splines. (“B” stands for “basis”, and the 
term “spline” refers to a flexible piece of material used by artists to draw curves.) We discuss this in 
more detail in Section 11.5.1. 


11.5.1 B-spline basis functions 


A spline is a piecewise polynomial of degree $D$, where the locations of the pieces are defined by a set
of knots, $t_1 < \cdots < t_m$. More precisely, the polynomial is defined on each of the intervals $(-\infty, t_1)$,
$[t_1, t_2]$, $\cdots$, $[t_m, \infty)$. The function is continuous and has continuous derivatives of orders $1, \ldots, D-1$
at its knot points. It is common to use cubic splines, in which $D = 3$. This ensures the function is
continuous, and has continuous first and second derivatives at each knot.

We will skip the details on how B-splines are computed, since it is not relevant to our purposes.
Suffice it to say that we can call the patsy.bs function to convert the $N \times 1$ data matrix $\mathbf{X}$ into an
$N \times (K + D + 1)$ design matrix $\mathbf{B}$, where $K$ is the number of knots and $D$ is the degree. (Alternatively,
you can specify the desired number of basis functions, and let patsy work out the number and locations
of the knots.)

Figure 11.16 illustrates this approach, where we use B-splines of degree 0, 1 and 3, with 3 knots. 
By taking a weighted combination of these basis functions, we can get increasingly smooth functions, 
as shown in the bottom row. 

Figure 11.16: Illustration of B-splines of degree 0, 1 and 3. Top row: unweighted basis functions. Dots mark
the locations of the 3 internal knots at [0.25, 0.5, 0.75]. Bottom row: weighted combination of basis functions
using random weights. Generated by splines_basis_weighted.ipynb. Adapted from Figure 5.4 of [MKL11].
Used with kind permission of Osvaldo Martin.

Figure 11.17: Design matrix for B-splines of degree (a) 0, (b) 1 and (c) 3. We evaluate the splines on 20
inputs ranging from 0 to 1. Generated by splines_basis_heatmap.ipynb. Adapted from Figure 5.6 of [MKL11].
Used with kind permission of Osvaldo Martin.

We see from Figure 11.16 that each individual basis function has local support. At any given input
point $x$, only $D + 1$ basis functions will be “active”. This is more obvious if we plot the design matrix
$\mathbf{B}$ itself. Let us first consider the piecewise constant spline, shown in Figure 11.17(a). The first
B-spline (column 0) is 1 for the first 5 observations, and otherwise 0. The second B-spline (column
1) is 0 for the first 5 observations, 1 for the second 5, and then 0 again. And so on. Now consider
the linear spline, shown in Figure 11.17(b). The first B-spline (column 0) goes from 1 to 0, the next
three splines go from 0 to 1 and back to 0; and the last spline (column 4) goes from 0 to 1; this
reflects the triangular shapes shown in the top middle panel of Figure 11.16. Finally consider the
cubic spline, shown in Figure 11.17(c). Here the pattern of activations is smoother, and the resulting
model fits will be smoother too.
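
For readers who want to see where the $N \times (K + D + 1)$ shape and the local support come from, here is a minimal sketch of a B-spline design matrix built with the standard Cox-de Boor recursion. It is an independent illustration, not how patsy is implemented; the function name, the boundary handling (each boundary knot repeated $D + 1$ times), and the default interval $[0, 1]$ are assumptions.

```python
import numpy as np

def bspline_design_matrix(x, internal_knots, degree, lo=0.0, hi=1.0):
    """B-spline basis on [lo, hi] via the Cox-de Boor recursion."""
    x = np.asarray(x, dtype=float)
    # Repeat each boundary knot degree+1 times so the basis spans [lo, hi]
    t = np.concatenate([np.full(degree + 1, lo),
                        np.sort(internal_knots),
                        np.full(degree + 1, hi)])
    # Degree 0: indicator of each half-open knot interval [t_i, t_{i+1})
    B = np.zeros((len(x), len(t) - 1))
    for i in range(len(t) - 1):
        B[:, i] = (t[i] <= x) & (x < t[i + 1])
    # Assign x == hi to the last non-empty interval so the right endpoint is covered
    last = np.max(np.nonzero(t[:-1] < t[1:])[0])
    B[x == hi, last] = 1.0
    # Raise the degree one step at a time; terms with a zero denominator are dropped
    for p in range(1, degree + 1):
        B_next = np.zeros((len(x), len(t) - p - 1))
        for i in range(len(t) - p - 1):
            left_den = t[i + p] - t[i]
            right_den = t[i + p + 1] - t[i + 1]
            if left_den > 0:
                B_next[:, i] += (x - t[i]) / left_den * B[:, i]
            if right_den > 0:
                B_next[:, i] += (t[i + p + 1] - x) / right_den * B[:, i + 1]
        B = B_next
    return B

x = np.linspace(0, 1, 20)                 # 20 inputs ranging from 0 to 1
knots = [0.25, 0.5, 0.75]                 # the 3 internal knots used in the figures
designs = {D: bspline_design_matrix(x, knots, D) for D in (0, 1, 3)}
```

Each row of the resulting matrix sums to one, and at most $D + 1$ entries per row are non-zero, which is the local support property discussed above.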
Figure 11.18: Fitting a cubic spline regression model with 15 knots to a 1d dataset. Generated by
splines_cherry_blossoms.ipynb. Adapted from Figure 5.8 of [McE20].


11.5.2 Fitting a linear model using a spline basis 


Once we have computed the design matrix $\mathbf{B}$, we can use it to fit a linear model using least squares
or ridge regression. (It is usually best to use some regularization.) As an example, we consider a
dataset from [McE20, Sec 4.5], which records the first day of the year, and the corresponding
temperature, that marks the start of the cherry blossom season in Japan. (We use this dataset since
it has interesting semi-periodic structure.) We fit the data using a cubic spline. We pick 15 knots,
spaced according to quantiles of the data. The results are shown in Figure 11.18. We see that the fit
is reasonable. Using more knots would improve the quality of the fit, but would eventually result in
overfitting. We can select the number of knots using a model selection method, such as grid search
plus cross validation.


11.5.3 Smoothing splines 


Smoothing splines are related to regression splines, but use $N$ knots, where $N$ is the number of
datapoints. That is, they are non-parametric models, since the number of parameters grows with the
size of the data, rather than being fixed a priori. To avoid overfitting, smoothing splines rely on $\ell_2$
regularization. This technique is closely related to Gaussian process regression, which we discuss in
Section 17.2.


11.5.4 Generalized additive models 


Figure 11.19: (a) Illustration of robust linear regression. Generated by linregRobustDemoCombined.ipynb. (b)
Illustration of $\ell_2$, $\ell_1$, and Huber loss functions with $\delta = 1.5$. Generated by huberLossPlot.ipynb.

A generalized additive model or GAM extends spline regression to the case of multidimensional
inputs [HT90]. It does this by ignoring interactions between the inputs, and assuming the function
has the following additive form:

$$f(\mathbf{x}; \boldsymbol{\theta}) = \alpha + \sum_{d=1}^D f_d(x_d) \tag{11.107}$$

where each $f_d$ is a regression or smoothing spline. This model can be fit using backfitting, which
iteratively fits each $f_d$ to the partial residuals generated by the other terms. We can extend GAMs
beyond the regression case (e.g., to classification) by using a link function, as in generalized linear
models (Chapter 12).
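
The backfitting loop for the additive model in Equation (11.107) can be sketched as follows. To keep the example self-contained, each $f_d$ is fit with a cubic polynomial as a stand-in for a regression or smoothing spline; this substitution, the centering of each component, the fixed number of sweeps, and all names are assumptions made for illustration.

```python
import numpy as np

def fit_smoother(x, r, degree=3):
    """Stand-in 1d smoother: least squares cubic polynomial (a spline in practice)."""
    coeffs = np.polyfit(x, r, degree)
    return np.polyval(coeffs, x)

def backfit_gam(X, y, num_sweeps=20):
    """Backfitting for f(x) = alpha + sum_d f_d(x_d); returns in-sample components."""
    N, D = X.shape
    alpha = y.mean()
    fitted = np.zeros((N, D))                  # fitted[:, d] holds f_d(x_{nd})
    for _ in range(num_sweeps):
        for d in range(D):
            # Partial residual: remove the intercept and every component except f_d
            partial = y - alpha - fitted.sum(axis=1) + fitted[:, d]
            fitted[:, d] = fit_smoother(X[:, d], partial)
            fitted[:, d] -= fitted[:, d].mean()   # center f_d so alpha stays identifiable
    return alpha, fitted

# Illustrative data from an additive function of two inputs
rng = np.random.default_rng(3)
X = rng.uniform(-1, 1, size=(200, 2))
y = 1.0 + X[:, 0] ** 2 - 2.0 * X[:, 1] + 0.05 * rng.normal(size=200)
alpha, fitted = backfit_gam(X, y)
residual = y - alpha - fitted.sum(axis=1)
```

Swapping `fit_smoother` for a penalized spline fit would give a more faithful GAM; the outer loop would stay the same.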


11.6 Robust linear regression * 


It is very common to model the noise in regression models using a Gaussian distribution with zero
mean and constant variance, $r_n \sim \mathcal{N}(0, \sigma^2)$, where $r_n = y_n - \mathbf{w}^\top \mathbf{x}_n$. In this case, maximizing
likelihood is equivalent to minimizing the sum of squared residuals, as we have seen. However, if
we have outliers in our data, this can result in a poor fit, as illustrated in Figure 11.19(a). (The
outliers are the points on the bottom of the figure.) This is because squared error penalizes deviations
quadratically, so points far from the line have more effect on the fit than points near to the line.
One way to achieve robustness to outliers is to replace the Gaussian distribution for the response
variable with a distribution that has heavy tails. Such a distribution will assign higher likelihood
to outliers, without having to perturb the straight line to “explain” them. We discuss several possible
alternative probability distributions for the response variable below; see Table 11.2 for a summary.


11.6.1 Laplace likelihood 


In Section 2.7.3, we noted that the Laplace distribution is also robust to outliers. If we use this as
our observation model for regression, we get the following likelihood:

$$p(y | \mathbf{x}, \mathbf{w}, b) = \text{Laplace}(y | \mathbf{w}^\top \mathbf{x}, b) \propto \exp\left(-\frac{1}{b} |y - \mathbf{w}^\top \mathbf{x}|\right) \tag{11.108}$$

The robustness arises from the use of $|y - \mathbf{w}^\top \mathbf{x}|$ instead of $(y - \mathbf{w}^\top \mathbf{x})^2$. Figure 11.19(a) gives an
example of the method in action.

Likelihood   Prior         Posterior     Name                Section
Gaussian     Uniform       Point         Least squares       11.2.2
Student      Uniform       Point         Robust regression   11.6.2
Laplace      Uniform       Point         Robust regression   11.6.1
Gaussian     Gaussian      Point         Ridge               11.3
Gaussian     Laplace       Point         Lasso               11.4
Gaussian     Gauss-Gamma   Gauss-Gamma   Bayesian lin. reg   11.7


Table 11.2: Summary of various likelihoods, priors and posteriors used for linear regression. The likelihood
refers to the distributional form of $p(y | \mathbf{x}, \mathbf{w}, \sigma^2)$, and the prior refers to the distributional form of $p(\mathbf{w})$. The
posterior refers to the distributional form of $p(\mathbf{w} | \mathcal{D})$. “Point” stands for the degenerate distribution $\delta(\mathbf{w} - \hat{\mathbf{w}})$,
where $\hat{\mathbf{w}}$ is the MAP estimate. MLE is equivalent to using a point posterior and a uniform prior.


11.6.1.1 Computing the MLE using linear programming 


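We can compute the MLE for this model using linear programming. As we explain in Section 8.5.3,
this is a way to solve constrained optimization problems of the form

$$\operatorname*{argmin}_{\mathbf{v}} \mathbf{c}^\top \mathbf{v} \quad \text{s.t.} \quad \mathbf{A}\mathbf{v} \le \mathbf{b} \tag{11.109}$$

where $\mathbf{v} \in \mathbb{R}^n$ is the set of $n$ unknown parameters, $\mathbf{c}^\top \mathbf{v}$ is the linear objective function we want to
minimize, and $\mathbf{a}_i^\top \mathbf{v} \le b_i$ is a set of $m$ linear constraints we must satisfy. To apply this to our problem,
let us define $\mathbf{v} = (w_1, \ldots, w_D, e_1, \ldots, e_N) \in \mathbb{R}^{D+N}$, where $e_i = |y_i - \hat{y}_i|$ is the residual error for
example $i$. We want to minimize the sum of the residuals, so we define $\mathbf{c} = (0, \cdots, 0, 1, \cdots, 1) \in \mathbb{R}^{D+N}$,
where the first $D$ elements are 0, and the last $N$ elements are 1.

We need to enforce the constraint that $e_i = |\hat{y}_i - y_i|$. In fact it is sufficient to enforce the constraint
that $|\mathbf{w}^\top \mathbf{x}_i - y_i| \le e_i$, since minimizing the sum of the $e_i$'s will “push down” on this constraint
and make it tight. Since $|a| \le b \implies -b \le a \le b$, we can encode $|\mathbf{w}^\top \mathbf{x}_i - y_i| \le e_i$ as two linear
constraints:

$$\begin{aligned}
e_i &\ge \mathbf{w}^\top \mathbf{x}_i - y_i && (11.110) \\
e_i &\ge -(\mathbf{w}^\top \mathbf{x}_i - y_i) && (11.111)
\end{aligned}$$

We can write Equation (11.110) as

$$(\mathbf{x}_i, 0, \cdots, 0, -1, 0, \cdots, 0)^\top \mathbf{v} \le y_i \tag{11.112}$$

where the first $D$ entries are filled with $\mathbf{x}_i$, and the $-1$ is in the $(D + i)$'th entry of the vector.
Similarly we can write Equation (11.111) as

$$(-\mathbf{x}_i, 0, \cdots, 0, -1, 0, \cdots, 0)^\top \mathbf{v} \le -y_i \tag{11.113}$$

We can write these constraints in the form $\mathbf{A}\mathbf{v} \le \mathbf{b}$ by defining $\mathbf{A} \in \mathbb{R}^{2N \times (N+D)}$ as follows:

$$\mathbf{A} = \begin{pmatrix}
\mathbf{x}_1 & -1 & 0 & 0 & \cdots & 0 \\
-\mathbf{x}_1 & -1 & 0 & 0 & \cdots & 0 \\
\mathbf{x}_2 & 0 & -1 & 0 & \cdots & 0 \\
-\mathbf{x}_2 & 0 & -1 & 0 & \cdots & 0 \\
 & & \vdots & & &
\end{pmatrix} \tag{11.114}$$

and defining $\mathbf{b} \in \mathbb{R}^{2N}$ as

$$\mathbf{b} = (y_1, -y_1, y_2, -y_2, \cdots, y_N, -y_N) \tag{11.115}$$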
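
The construction of $\mathbf{c}$, $\mathbf{A}$ and $\mathbf{b}$ in Equations (11.109) to (11.115) can be written out directly. The sketch below only builds the matrices and checks them on a candidate $\mathbf{w}$; it does not call an LP solver. The function name `laplace_mle_lp` and the synthetic data are illustrative assumptions.

```python
import numpy as np

def laplace_mle_lp(X, y):
    """Build (c, A, b) so that min c^T v s.t. A v <= b gives the Laplace MLE.

    The unknown is v = (w_1..w_D, e_1..e_N), as in the text.
    """
    N, D = X.shape
    c = np.concatenate([np.zeros(D), np.ones(N)])   # minimize the sum of the e_i
    A = np.zeros((2 * N, D + N))
    b = np.zeros(2 * N)
    for i in range(N):
        # Equation (11.112):  x_i^T w - e_i <= y_i
        A[2 * i, :D] = X[i]
        A[2 * i, D + i] = -1.0
        b[2 * i] = y[i]
        # Equation (11.113): -x_i^T w - e_i <= -y_i
        A[2 * i + 1, :D] = -X[i]
        A[2 * i + 1, D + i] = -1.0
        b[2 * i + 1] = -y[i]
    return c, A, b

rng = np.random.default_rng(4)
X = rng.normal(size=(6, 2))
y = rng.normal(size=6)
c, A, b = laplace_mle_lp(X, y)

# Any w paired with e_i = |w^T x_i - y_i| is feasible, and the objective
# then equals the sum of absolute residuals.
w = np.array([0.5, -1.0])
v = np.concatenate([w, np.abs(X @ w - y)])
```

These arrays could then be handed to a general-purpose LP solver. If one uses `scipy.optimize.linprog`, note that its default variable bounds are non-negative, so the bounds on the $\mathbf{w}$ entries would need to be relaxed explicitly.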


11.6.2 Student-t likelihood 


In Section 2.7.1, we discussed the robustness properties of the Student distribution. To use this in a
regression context, we can just make the mean be a linear function of the inputs, as proposed in
[Zel76]:

$$p(y | \mathbf{x}, \mathbf{w}, \sigma^2, \nu) = \mathcal{T}(y | \mathbf{w}^\top \mathbf{x}, \sigma^2, \nu) \tag{11.116}$$

We can fit this model using SGD or EM (see [Mur23] for details).


11.6.3 Huber loss 


An alternative to minimizing the NLL using a Laplace or Student likelihood is to use the Huber
loss, which is defined as follows:

$$\ell_{\text{huber}}(r, \delta) = \begin{cases} r^2 / 2 & \text{if } |r| \le \delta \\ \delta |r| - \delta^2 / 2 & \text{if } |r| > \delta \end{cases} \tag{11.117}$$

This is equivalent to $\ell_2$ for errors that are smaller than $\delta$, and is equivalent to $\ell_1$ for larger errors.
See Figure 5.3 for a plot.

The advantage of this loss function is that it is everywhere differentiable. Consequently optimizing
the Huber loss is much faster than using the Laplace likelihood, since we can use standard smooth
optimization methods (such as SGD) instead of linear programming. Figure 11.19 gives an illustration
of the Huber loss function in action. The results are qualitatively similar to the Laplace and Student
methods.

The parameter $\delta$, which controls the degree of robustness, is usually set by hand, or by cross-
validation. However, [Bar19] shows how to approximate the Huber loss such that we can optimize $\delta$
by gradient methods.
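
Equation (11.117) and its derivative are simple to write down in code. The following sketch uses hypothetical function names; the derivative is just the residual clipped to $[-\delta, \delta]$, which is what makes the loss differentiable everywhere.

```python
import numpy as np

def huber_loss(r, delta):
    """Huber loss of Equation (11.117), applied elementwise to residuals r."""
    r = np.asarray(r, dtype=float)
    quadratic = 0.5 * r ** 2                            # used when |r| <= delta
    linear = delta * np.abs(r) - 0.5 * delta ** 2       # used when |r| > delta
    return np.where(np.abs(r) <= delta, quadratic, linear)

def huber_grad(r, delta):
    """Derivative wrt r: equal to r inside [-delta, delta], clipped to +/- delta outside."""
    return np.clip(np.asarray(r, dtype=float), -delta, delta)

delta = 1.5
small = huber_loss(1.0, delta)     # quadratic regime
large = huber_loss(3.0, delta)     # linear regime
```

The two branches agree at $|r| = \delta$ (both give $\delta^2 / 2$), and the gradient is bounded by $\delta$, so a single outlier cannot dominate a gradient step.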


11.6.4 RANSAC 


In the computer vision community, a common approach to robust regression is to use RANSAC,
which stands for “random sample consensus” [FB81]. This works as follows: we sample a small initial
set of points, fit the model to them, identify outliers wrt this model (based on large residuals), remove
the outliers, and then refit the model to the inliers. We repeat this for many random initial sets and
pick the best model.

A deterministic alternative to RANSAC is the following iterative scheme: initially we assume that
all datapoints are inliers, and we fit the model to compute $\hat{\mathbf{w}}_0$; then, for each iteration $t$, we identify
the outlier points as those with large residual under the model $\hat{\mathbf{w}}_t$, remove them, and refit the model
to the remaining points to get $\hat{\mathbf{w}}_{t+1}$. Even though this hard thresholding scheme makes the problem
nonconvex, this simple scheme can be proved to rapidly converge to the optimal estimate under some
reasonable assumptions [Muk+19; Sug+19].
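
The deterministic scheme described above (not RANSAC itself) can be sketched in a few lines. The text does not say how to decide which residuals count as “large”; here we assume a fixed fraction of points, `keep_frac`, is retained at each step. That rule, the function name, and the synthetic data are all illustrative assumptions.

```python
import numpy as np

def trimmed_least_squares(X, y, keep_frac=0.8, num_iters=20):
    """Alternate between least squares fitting and dropping the largest residuals."""
    N = X.shape[0]
    num_keep = int(np.ceil(keep_frac * N))
    inliers = np.arange(N)                       # initially every point is an inlier
    for _ in range(num_iters):
        w = np.linalg.lstsq(X[inliers], y[inliers], rcond=None)[0]
        residuals = np.abs(y - X @ w)
        new_inliers = np.sort(np.argsort(residuals)[:num_keep])
        if np.array_equal(new_inliers, inliers):
            break                                # inlier set is stable, so w is already fit to it
        inliers = new_inliers
    else:
        # Ran out of iterations: refit so that w matches the returned inlier set
        w = np.linalg.lstsq(X[inliers], y[inliers], rcond=None)[0]
    return w, inliers

# Line y = 1 + 2x with the first 10 of 100 points shifted far below it
rng = np.random.default_rng(5)
x = rng.uniform(-1, 1, size=100)
X = np.column_stack([np.ones(100), x])
y = 1.0 + 2.0 * x + 0.05 * rng.normal(size=100)
y[:10] -= 10.0
w_robust, inliers = trimmed_least_squares(X, y)
w_ols = np.linalg.lstsq(X, y, rcond=None)[0]
```

On this example the ordinary least squares intercept is pulled down by the outliers, whereas the trimmed fit recovers parameters close to $(1, 2)$.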


11.7 Bayesian linear regression * 


We have seen how to compute the MLE and MAP estimate for linear regression models under various
priors. In this section, we discuss how to compute the posterior over the parameters, $p(\boldsymbol{\theta} | \mathcal{D})$. For
simplicity, we assume the variance is known, so we just want to compute $p(\mathbf{w} | \mathcal{D}, \sigma^2)$. See the sequel
to this book, [Mur23], for the general case.


11.7.1 Priors


For simplicity, we will use a Gaussian prior:

$$p(\mathbf{w}) = \mathcal{N}(\mathbf{w} | \breve{\mathbf{w}}, \breve{\boldsymbol{\Sigma}}) \tag{11.118}$$

This is a small generalization of the prior that we use in ridge regression (Section 11.3). See the
sequel to this book, [Mur23], for a discussion of other priors.


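11.7.2 Posteriors


We can rewrite the likelihood in terms of an MVN as follows:

$$p(\mathcal{D} | \mathbf{w}, \sigma^2) = \prod_{n=1}^N p(y_n | \mathbf{w}^\top \mathbf{x}_n, \sigma^2) = \mathcal{N}(\mathbf{y} | \mathbf{X}\mathbf{w}, \sigma^2 \mathbf{I}_N) \tag{11.119}$$

where $\mathbf{I}_N$ is the $N \times N$ identity matrix. We can then use Bayes rule for Gaussians (Equation (3.37))
to derive the posterior, which is as follows:

$$\begin{aligned}
p(\mathbf{w} | \mathbf{X}, \mathbf{y}, \sigma^2) &\propto \mathcal{N}(\mathbf{w} | \breve{\mathbf{w}}, \breve{\boldsymbol{\Sigma}}) \, \mathcal{N}(\mathbf{y} | \mathbf{X}\mathbf{w}, \sigma^2 \mathbf{I}_N) = \mathcal{N}(\mathbf{w} | \hat{\mathbf{w}}, \hat{\boldsymbol{\Sigma}}) && (11.120) \\
\hat{\mathbf{w}} &\triangleq \hat{\boldsymbol{\Sigma}} \left( \breve{\boldsymbol{\Sigma}}^{-1} \breve{\mathbf{w}} + \frac{1}{\sigma^2} \mathbf{X}^\top \mathbf{y} \right) && (11.121) \\
\hat{\boldsymbol{\Sigma}} &\triangleq \left( \breve{\boldsymbol{\Sigma}}^{-1} + \frac{1}{\sigma^2} \mathbf{X}^\top \mathbf{X} \right)^{-1} && (11.122)
\end{aligned}$$

where $\hat{\mathbf{w}}$ is the posterior mean, and $\hat{\boldsymbol{\Sigma}}$ is the posterior covariance.
If $\breve{\mathbf{w}} = \mathbf{0}$ and $\breve{\boldsymbol{\Sigma}} = \tau^2 \mathbf{I}$, then the posterior mean becomes $\hat{\mathbf{w}} = \frac{1}{\sigma^2} \hat{\boldsymbol{\Sigma}} \mathbf{X}^\top \mathbf{y}$. If we define $\lambda = \frac{\sigma^2}{\tau^2}$, we
recover the ridge regression estimate, $\hat{\mathbf{w}} = (\lambda \mathbf{I} + \mathbf{X}^\top \mathbf{X})^{-1} \mathbf{X}^\top \mathbf{y}$, which matches Equation (11.57).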
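
Equations (11.121) and (11.122) map directly onto a few lines of linear algebra. The sketch below uses explicit matrix inverses for readability (a Cholesky-based solve would be preferable numerically); the function and argument names are hypothetical.

```python
import numpy as np

def gaussian_posterior(X, y, sigma2, prior_mean, prior_cov):
    """Posterior N(w | post_mean, post_cov) for a Gaussian prior and known noise variance."""
    prior_prec = np.linalg.inv(prior_cov)
    post_cov = np.linalg.inv(prior_prec + X.T @ X / sigma2)                 # Equation (11.122)
    post_mean = post_cov @ (prior_prec @ prior_mean + X.T @ y / sigma2)     # Equation (11.121)
    return post_mean, post_cov

# Check the ridge special case: zero prior mean, prior covariance tau^2 * I
rng = np.random.default_rng(6)
X = rng.normal(size=(30, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.3 * rng.normal(size=30)
sigma2, tau2 = 0.09, 4.0
post_mean, post_cov = gaussian_posterior(X, y, sigma2, np.zeros(3), tau2 * np.eye(3))
lam = sigma2 / tau2
w_ridge = np.linalg.solve(lam * np.eye(3) + X.T @ X, X.T @ y)
```

The posterior mean coincides with the ridge estimate for $\lambda = \sigma^2 / \tau^2$, while the posterior covariance supplies the uncertainty that a point estimate lacks.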
11.7.3 Example


Suppose we have a 1d regression model of the form $f(x; \mathbf{w}) = w_0 + w_1 x_1$, where the true parameters
are $w_0 = -0.3$ and $w_1 = 0.5$. We now perform inference $p(\mathbf{w} | \mathcal{D})$ and visualize the 2d prior and
posterior as the size of the training set $N$ increases.

In particular, in Figure 11.20 (which inspired the front cover of this book), we plot the likelihood, 
the posterior, and an approximation to the posterior predictive distribution.⁴ Each row plots these
distributions as we increase the amount of training data, N. We now explain each row: 


• In the first row, N = 0, so the posterior is the same as the prior. In this case, our predictions are
“all over the place”, since our prior is essentially uniform.

• In the second row, N = 1, so we have seen one data point (the blue circle in the plot in the third
column). Our posterior becomes constrained by the corresponding likelihood, and our predictions
pass close to the observed data. However, we see that the posterior has a ridge-like shape, reflecting
the fact that there are many possible solutions, with different slopes/intercepts. This makes sense
since we cannot uniquely infer two parameters ($w_0$ and $w_1$) from one observation.

• In the third row, N = 2. In this case, the posterior becomes much narrower since we have two
constraints from the likelihood. Our predictions about the future are all now closer to the training
data.

• In the fourth (last) row, N = 100. Now the posterior is essentially a delta function, centered on
the true value of $w_* = (-0.3, 0.5)$, indicated by a white cross in the plots in the first and second
columns. The variation in our predictions is due to the inherent Gaussian noise with magnitude $\sigma^2$.


This example illustrates that, as the amount of data increases, the posterior mean estimate,
$\hat{\mu} = \mathbb{E}[w|\mathcal{D}]$, converges to the true value $w_*$ that generated the data. We thus say that the Bayesian
estimate is a consistent estimator (see Section 5.3.2 for more details). We also see that our posterior 
uncertainty decreases over time. This is what we mean when we say we are “learning” about the 
parameters as we see more data. 
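
The sequential nature of this example can be sketched in a few lines of NumPy: after each observation, the current posterior becomes the prior for the next update. The noise level `sigma`, the prior covariance, and the sampling range for the inputs are assumptions chosen for illustration; they are not stated in the text.

```python
import numpy as np

rng = np.random.default_rng(1)
w_true = np.array([-0.3, 0.5])   # true (w0, w1) from the example
sigma = 0.2                      # assumed noise standard deviation
N = 100
x = rng.uniform(-1, 1, size=N)   # assumed input range
Phi = np.column_stack([np.ones(N), x])   # each row is [1, x_n]
y = Phi @ w_true + sigma * rng.normal(size=N)

mean = np.zeros(2)               # assumed prior mean
cov = 0.5 * np.eye(2)            # assumed prior covariance
history = {0: (mean.copy(), cov.copy())}
for n in range(N):
    phi = Phi[n]
    prec_old = np.linalg.inv(cov)
    # Rank-one update of the precision with the new data point.
    cov_new = np.linalg.inv(prec_old + np.outer(phi, phi) / sigma**2)
    mean = cov_new @ (prec_old @ mean + phi * y[n] / sigma**2)
    cov = cov_new
    if n + 1 in (1, 2, 100):
        history[n + 1] = (mean.copy(), cov.copy())

for n_seen, (m, S) in history.items():
    print(n_seen, "mean:", m.round(3), "std:", np.sqrt(np.diag(S)).round(3))
```

The printed standard deviations shrink as more points are absorbed, which is the numerical counterpart of the posterior collapsing towards a point in the figure.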


11.7.4 Computing the posterior predictive 


We have discussed how to compute our uncertainty about the parameters of the model, p(w|D). 
But what about the uncertainty associated with our predictions about future outputs? Using 
Equation (3.38), we can show that the posterior predictive distribution at a test point x is also 
Gaussian: 


$$p(y|x,\mathcal{D},\sigma^2) = \int \mathcal{N}(y|x^T w,\sigma^2)\,\mathcal{N}(w|\hat{\mu},\hat{\Sigma})\,dw \tag{11.123}$$
$$= \mathcal{N}(y|\hat{\mu}^T x, \hat{\sigma}^2(x)) \tag{11.124}$$


4. To approximate this, we draw some samples from the posterior, $w_s \sim \mathcal{N}(\hat{\mu}, \hat{\Sigma})$, and then plot the line $\mathbb{E}[y|x, w_s]$,
where x ranges over $[-1, 1]$, for each sampled parameter value.


Figure 11.20: Sequential Bayesian inference of the parameters of a linear regression model $p(y|x) = \mathcal{N}(y|w_0 +
w_1 x_1, \sigma^2)$. Left column: likelihood function for current data point. Middle column: posterior given first N
data points, $p(w_0, w_1|x_{1:N}, y_{1:N}, \sigma^2)$. Right column: samples from the current posterior predictive distribution.
Row 1: prior distribution (N = 0). Row 2: after 1 data point. Row 3: after 2 data points. Row 4: after 100
data points. The white cross in columns 1 and 2 represents the true parameter value; we see that the mode
of the posterior rapidly converges to this point. The blue circles in column 3 are the observed data points.
Adapted from Figure 3.7 of [Bis06]. Generated by linreg_2d_bayes_demo.ipynb.


where $\hat{\sigma}^2(x) \triangleq \sigma^2 + x^T \hat{\Sigma} x$ is the variance of the posterior predictive distribution at point $x$
after seeing the N training examples. The predicted variance depends on two terms: the variance
of the observation noise, $\sigma^2$, and the variance in the parameters, $\hat{\Sigma}$. The latter translates into
variance about observations in a way which depends on how close $x$ is to the training data $\mathcal{D}$. This is
illustrated in Figure 11.21(b), where we see that the error bars get larger as we move away from the 
training points, representing increased uncertainty. This can be important for certain applications, 
such as active learning, where we choose where to collect training data (see Section 19.4). 
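
As a rough illustration of how the predictive variance grows away from the data, the sketch below fits a quadratic model to inputs drawn from $[-1, 1]$ and evaluates the predictive variance at a point inside and a point well outside that range. The helper names, the prior covariance, the noise variance, and the synthetic data are all assumptions made for this example.

```python
import numpy as np

def poly_features(x, degree=2):
    # Columns are x^0, x^1, ..., x^degree.
    return np.vander(np.atleast_1d(x), degree + 1, increasing=True)

rng = np.random.default_rng(2)
sigma2 = 0.01                               # assumed known noise variance
x_train = rng.uniform(-1, 1, size=20)       # training inputs lie in [-1, 1]
y_train = (1.0 + 2.0 * x_train - 1.5 * x_train**2
           + np.sqrt(sigma2) * rng.normal(size=20))

Phi = poly_features(x_train)
prior_cov = 10.0 * np.eye(3)                # assumed vague zero-mean prior
post_cov = np.linalg.inv(np.linalg.inv(prior_cov) + Phi.T @ Phi / sigma2)
post_mean = post_cov @ (Phi.T @ y_train / sigma2)

def predictive(x):
    """Mean and variance of the Gaussian posterior predictive at x."""
    phi = poly_features(x)
    mean = phi @ post_mean
    var = sigma2 + np.einsum("nd,de,ne->n", phi, post_cov, phi)
    return mean, var

_, var_near = predictive(0.0)   # inside the training range
_, var_far = predictive(3.0)    # far outside the training range
print("variance near data:", var_near[0], " far from data:", var_far[0])
```

The noise term `sigma2` is a floor on the predictive variance; the parameter term is what produces the widening error bars.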

In some cases, it is computationally intractable to compute the parameter posterior, $p(w|\mathcal{D})$. In
such cases, we may choose to use a point estimate, $\hat{w}$, and then to use the plugin approximation.
This gives

$$p(y|x,\mathcal{D},\sigma^2) \approx \int \mathcal{N}(y|x^T w, \sigma^2)\,\delta(w - \hat{w})\,dw = p(y|x^T\hat{w}, \sigma^2) \tag{11.125}$$


We see that the posterior predictive variance is constant, and independent of the data, as illustrated in 
Figure 11.21(a). If we sample a parameter from this posterior, we will always recover a single function, 
as shown in Figure 11.21(c). By contrast, if we sample from the true posterior, $w_s \sim p(w|\mathcal{D}, \sigma^2)$, we
will get a range of different functions, as shown in Figure 11.21(d), which more accurately reflects 
our uncertainty. 
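
The contrast between the plugin approximation and the full posterior can be sketched as follows. We draw 10 "functions" under each scheme for a quadratic model and measure how much the curves differ from each other. The data, the prior covariance, and the noise variance are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(3)
sigma2 = 0.05                                  # assumed known noise variance
x_train = rng.uniform(-1, 1, size=15)
Phi = np.vander(x_train, 3, increasing=True)   # quadratic features [1, x, x^2]
y_train = Phi @ np.array([0.5, -1.0, 2.0]) + np.sqrt(sigma2) * rng.normal(size=15)

# Plugin approximation: a single point estimate (here the MLE).
w_mle = np.linalg.lstsq(Phi, y_train, rcond=None)[0]

# Gaussian posterior under an assumed N(0, 10 I) prior.
post_cov = np.linalg.inv(np.eye(3) / 10.0 + Phi.T @ Phi / sigma2)
post_cov = (post_cov + post_cov.T) / 2         # enforce exact symmetry
post_mean = post_cov @ (Phi.T @ y_train / sigma2)

x_grid = np.linspace(-2, 2, 50)
Phi_grid = np.vander(x_grid, 3, increasing=True)

# "Sampling" from a delta function always returns the same weights.
plugin_curves = np.tile(Phi_grid @ w_mle, (10, 1))
# Sampling from the posterior gives a different curve per draw.
w_samples = rng.multivariate_normal(post_mean, post_cov, size=10)
posterior_curves = w_samples @ Phi_grid.T

plugin_spread = plugin_curves.std(axis=0).max()
posterior_spread = posterior_curves.std(axis=0).max()
print("plugin spread:", plugin_spread, " posterior spread:", posterior_spread)
```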
Figure 11.21: (a) Plugin approximation to predictive density (we plug in the MLE of the parameters) when
fitting a second degree polynomial to some 1d data. (b) Posterior predictive density, obtained by integrating
out the parameters. Black curve is posterior mean, error bars are 2 standard deviations of the posterior
predictive density. (c) 10 samples from the plugin approximation to posterior predictive distribution. (d) 10
samples from the true posterior predictive distribution. Generated by linreg_post_pred_plot.ipynb.


11.7.5 The advantage of centering 


The astute reader might notice that the shape of the 2d posterior in Figure 11.20 is an elongated 
ellipse (which eventually collapses to a point as $N \to \infty$). This implies that there is a lot of posterior
correlation between the two parameters, which can cause computational difficulties. 

To understand why this happens, note that each data point induces a likelihood function corresponding
to a line which goes through that data point. When we look at all the data together, we see
that predictions with maximum likelihood must correspond to lines that go through the mean of the
data, $(\bar{x}, \bar{y})$. There are many such lines, but if we increase the slope, we must decrease the intercept.
Thus we can think of the set of high probability lines as spinning around the data mean, like a wheel
of fortune.⁵ This correlation between $w_0$ and $w_1$ is why the posterior has the form of a diagonal line.
(The Gaussian prior converts this into an elongated ellipse, but the posterior correlation still persists
until the sample size causes the posterior to shrink to a point.)

It can be hard to compute such elongated posteriors. One simple solution is to center the input
data, i.e., by using $x'_n = x_n - \bar{x}$. Now the lines can pivot around the origin, reducing the posterior


5. This analogy is from [Mar18, p96].

Figure 11.22: Posterior samples of $p(w_0, w_1|\mathcal{D})$ for 1d linear regression model $p(y|x,\theta) = \mathcal{N}(y|w_0 +
w_1 x, \sigma^2)$ with a Gaussian prior. (a) Original data. (b) Centered data. Generated by
linreg_2d_bayes_centering_pymc3.ipynb.


correlation between $w_0$ and $w_1$. See Figure 11.22 for an illustration. (We may also choose to divide
each $x_n$ by the standard deviation of that feature, as discussed in Section 10.2.8.)

Note that we can convert the posterior derived from fitting to the centered data back to the original 
coordinates by noting that 


$$y = w'_0 + w'_1 x' = w'_0 + w'_1(x - \bar{x}) = (w'_0 - w'_1\bar{x}) + w'_1 x \tag{11.126}$$


Thus the parameters on the uncentered data are $w_0 = w'_0 - w'_1\bar{x}$ and $w_1 = w'_1$.
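
The effect of centering can be checked numerically with the closed-form Gaussian posterior, as in the sketch below. We compute the posterior correlation between intercept and slope for raw and centered inputs, then map the centered estimate back to the original coordinates. The synthetic data, the prior variance `tau2`, and the helper names are assumptions for illustration (the figure in the text uses sampling instead).

```python
import numpy as np

def posterior(x, y, sigma2, tau2):
    """Gaussian posterior over (w0, w1) with prior N(0, tau2 * I)."""
    Phi = np.column_stack([np.ones_like(x), x])
    cov = np.linalg.inv(np.eye(2) / tau2 + Phi.T @ Phi / sigma2)
    mean = cov @ (Phi.T @ y / sigma2)
    return mean, cov

def correlation(cov):
    return cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])

rng = np.random.default_rng(7)
x = rng.normal(5.0, 1.0, size=50)     # inputs whose mean is far from zero
sigma2, tau2 = 0.25, 100.0            # assumed noise and prior variances
y = 1.0 + 2.0 * x + np.sqrt(sigma2) * rng.normal(size=50)

x_bar = x.mean()
mean_raw, cov_raw = posterior(x, y, sigma2, tau2)
mean_c, cov_c = posterior(x - x_bar, y, sigma2, tau2)
corr_raw, corr_centered = correlation(cov_raw), correlation(cov_c)
print("posterior correlation, raw:", corr_raw, " centered:", corr_centered)

# Convert the centered estimate back to the original coordinates.
w0 = mean_c[0] - mean_c[1] * x_bar
w1 = mean_c[1]
```

Both parameterizations describe the same line, so predictions are unchanged by the conversion; only the geometry of the posterior differs.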


11.7.6 Dealing with multicollinearity 


In many datasets, the input variables can be highly correlated with each other. Including all of 
them does not generally harm predictive accuracy (provided you use a suitable prior or regularizer to 
prevent overfitting). However, it can make interpretation of the coefficients more difficult. 

To illustrate this, we use a toy example from [McE20, Sec 6.1]. Suppose we have a dataset of N
people in which we record their heights $h_i$, as well as the length of their left legs $l_i$ and right legs $r_i$.
Suppose $h_i \sim \mathcal{N}(10, 2)$, so the average height is $\bar{h} = 10$ (in unspecified units). Suppose the length of
the legs is some fraction $\rho_i \sim \mathrm{Unif}(0.4, 0.5)$ of the height, plus a bit of Gaussian noise, specifically

Now suppose we want to predict the height of a person given measurement of their leg lengths. 
(I did mention this is a toy example!) Since both left and right legs are noisy measurements 
of the unknown quantity, it is useful to use both of them. So we use linear regression to fit
$p(h|l,r) = \mathcal{N}(h|\alpha + \beta_l l + \beta_r r, \sigma^2)$. We use vague priors, $\alpha, \beta_l, \beta_r \sim \mathcal{N}(0, 100)$, and $\sigma \sim \mathrm{Expon}(1)$.

Since the average leg length is $\bar{l} = 0.45\bar{h} = 4.5$, we might expect each $\beta$ coefficient to be around
$\bar{h}/\bar{l} = 10/4.5 = 2.2$. However, the posterior marginals shown in Figure 11.23 tell a different story: we
see that the posterior mean of $\beta_l$ is near 2.6, but $\beta_r$ is near -0.6. Thus it seems like the right leg feature
is not needed. This is because the regression coefficient for feature j encodes the value of knowing $x_j$
given that all the other features $x_{-j}$ are already known, as we discussed in Section 11.2.2.1. If we
already know the left leg, the marginal value of also knowing the right leg is small. However, if we

Figure 11.23: Posterior marginals for the parameters in the multi-leg example. Generated by
multi_collinear_legs_numpyro.ipynb.

Figure 11.24: Posteriors for the multi-leg example. (a) Joint posterior $p(\beta_l, \beta_r|\mathcal{D})$. (b) Posterior of $p(\beta_l +
\beta_r|\text{data})$. Generated by multi_collinear_legs_numpyro.ipynb.


rerun this example with slightly different data, we may reach the opposite conclusion, and favor the 
right leg over the left. 

We can gain more insight by looking at the joint distribution $p(\beta_l, \beta_r|\mathcal{D})$, shown in Figure 11.24a.
We see that the parameters are very highly correlated, so if $\beta_r$ is large, then $\beta_l$ is small, and vice
versa. The marginal distribution for each parameter does not capture this. However, it does show
that there is a lot of uncertainty about each parameter, showing that they are non-identifiable.
However, their sum is well-determined, as can be seen from Figure 11.24b, where we plot $p(\beta_l + \beta_r|\mathcal{D})$;
this is centered on 2.2, as we might expect.

This example goes to show that we must be careful trying to interpret the significance of individual 
coefficient estimates in a model, since they do not mean much in isolation. 
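
The same qualitative behavior can be reproduced with the closed-form Gaussian posterior, as in the rough sketch below. Several simplifications are assumed here and are not from the text: the noise variance is fixed to the least-squares residual variance instead of being given an exponential prior, the leg noise has standard deviation 0.02, the height has standard deviation 2, and the prior variance is 100. The notebook cited in the figure captions uses sampling instead, so the numbers will differ.

```python
import numpy as np

rng = np.random.default_rng(4)
N = 100
h = rng.normal(10.0, 2.0, size=N)            # heights (assumed std 2)
rho = rng.uniform(0.4, 0.5, size=N)          # leg length as fraction of height
l = rho * h + rng.normal(0, 0.02, size=N)    # left leg (assumed noise scale)
r = rho * h + rng.normal(0, 0.02, size=N)    # right leg (assumed noise scale)

X = np.column_stack([np.ones(N), l, r])      # columns: intercept, left, right
# Simplification: plug in the least-squares residual variance for sigma^2.
w_ols, *_ = np.linalg.lstsq(X, h, rcond=None)
sigma2 = np.sum((h - X @ w_ols) ** 2) / (N - 3)

prior_var = 100.0                            # assumed prior variance
post_cov = np.linalg.inv(np.eye(3) / prior_var + X.T @ X / sigma2)
post_mean = post_cov @ (X.T @ h / sigma2)

sd_l, sd_r = np.sqrt(post_cov[1, 1]), np.sqrt(post_cov[2, 2])
corr_lr = post_cov[1, 2] / (sd_l * sd_r)
# Var(beta_l + beta_r) = Var(beta_l) + Var(beta_r) + 2 Cov(beta_l, beta_r).
sd_sum = np.sqrt(post_cov[1, 1] + post_cov[2, 2] + 2 * post_cov[1, 2])

print("posterior std of beta_l, beta_r:", sd_l, sd_r)
print("posterior correlation:", corr_lr)
print("posterior mean and std of beta_l + beta_r:", post_mean[1] + post_mean[2], sd_sum)
```

The individual coefficients have wide, strongly anti-correlated posteriors, whereas their sum is tightly constrained.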


11.7.7 Automatic relevancy determination (ARD) * 


Consider a linear regression model with known observation noise but unknown regression weights,
$\mathcal{N}(y|Xw, \sigma^2 I)$. Suppose we use a Gaussian prior for the weights, $w_j \sim \mathcal{N}(0, 1/\alpha_j)$, where $\alpha_j$ is the
precision of the j'th parameter. Now suppose we estimate the prior precisions as follows:

$$\hat{\alpha} = \operatorname*{argmax}_{\alpha}\, p(y|X,\alpha) \tag{11.127}$$

where

$$p(y|X,\alpha) = \int p(y|Xw,\sigma^2)\,p(w|0,\mathrm{diag}(\alpha)^{-1})\,dw \tag{11.128}$$


is the marginal likelihood. This is an example of empirical Bayes, since we are estimating the prior 
from data. We can view this as a computational shortcut to a fully Bayesian approach. However, 
there are additional advantages. In particular, suppose, after estimating $\alpha$, we compute the MAP
estimate


$$\hat{w} = \operatorname*{argmax}_{w}\, \mathcal{N}(w|0, \hat{\alpha}^{-1}) \tag{11.129}$$


This results in a sparse estimate for w, which is perhaps surprising given that the Gaussian prior for 
w is not sparsity promoting. The reasons for this are explained in the sequel to this book. 

This technique is known as sparse Bayesian learning [Tip01] or automatic relevancy determination
(ARD) [Mac95; Nea96]. It was originally developed for neural networks (where sparsity
is applied to the first layer weights), but here we apply it to linear models. See also Section 17.4.1,
where we apply it to kernelized linear models.
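
To give a feel for the sparsifying effect, here is a rough sketch that maximizes the marginal likelihood with a simple fixed-point iteration, re-estimating each precision as $\alpha_j \leftarrow \gamma_j / m_j^2$ with $\gamma_j = 1 - \alpha_j \Sigma_{jj}$, where $m$ and $\Sigma$ are the posterior mean and covariance of the weights under the current $\alpha$. This particular update rule, the function name `ard_fit`, the clipping constants, and the synthetic data are assumptions chosen for illustration; it is not presented as the algorithm used in the sequel.

```python
import numpy as np

def ard_fit(X, y, sigma2, n_iter=100, alpha_max=1e8):
    """Fixed-point re-estimation of per-weight prior precisions (a sketch)."""
    D = X.shape[1]
    alpha = np.ones(D)
    XtX, Xty = X.T @ X / sigma2, X.T @ y / sigma2
    for _ in range(n_iter):
        S = np.linalg.inv(np.diag(alpha) + XtX)   # posterior covariance
        m = S @ Xty                                # posterior mean
        # gamma_j in [0, 1] measures how well the data determines weight j.
        gamma = np.clip(1.0 - alpha * np.diag(S), 0.0, 1.0)
        alpha = np.clip(gamma / (m**2 + 1e-12), 1e-8, alpha_max)
    S = np.linalg.inv(np.diag(alpha) + XtX)
    return S @ Xty, alpha

rng = np.random.default_rng(9)
N, D = 100, 6
X = rng.normal(size=(N, D))
w_true = np.array([2.0, -3.0, 0.0, 0.0, 0.0, 0.0])   # only 2 relevant features
sigma2 = 0.04
y = X @ w_true + np.sqrt(sigma2) * rng.normal(size=N)

w_ard, alpha = ard_fit(X, y, sigma2)
print("weights:   ", w_ard.round(3))
print("precisions:", alpha.round(2))
```

Irrelevant features end up with very large precisions, which pins their weights close to zero, whereas relevant features keep small precisions.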


11.8 Exercises 
Exercise 11.1 [Multi-output linear regression *] 


(Source: Jaakkola.) 


Consider a linear regression model with a 2 dimensional response vector $y_i \in \mathbb{R}^2$. Suppose we have some
binary input data, $x_i \in \{0, 1\}$. The training data is as follows:


x | y
0 | $(-1, -2)^T$
0 | $(-2, -1)^T$
1 | $(1, 1)^T$
1 | $(1, 2)^T$
1 | $(2, 1)^T$

Let us embed each $x_i$ into 2d using the following basis function:

$$\phi(0) = (1, 0)^T, \quad \phi(1) = (0, 1)^T \tag{11.130}$$

The model becomes

$$\hat{y} = W^T \phi(x) \tag{11.131}$$


where W is a $2 \times 2$ matrix. Compute the MLE for W from the above data.


Exercise 11.2 [Centering and ridge regression] 


Assume that $\bar{x} = 0$, so the input data has been centered. Show that the optimizer of

$$J(w, w_0) = (y - Xw - w_0 1)^T (y - Xw - w_0 1) + \lambda w^T w \tag{11.132}$$

is

$$\hat{w}_0 = \bar{y} \tag{11.133}$$
$$w = (X^T X + \lambda I)^{-1} X^T y \tag{11.134}$$
Exercise 11.3 [Partial derivative of the RSS *]
Let $\mathrm{RSS}(w) = \|Xw - y\|_2^2$ be the residual sum of squares.
a. Show that

$$\frac{\partial}{\partial w_k}\mathrm{RSS}(w) = a_k w_k - c_k \tag{11.135}$$
$$a_k = 2\sum_{i=1}^n x_{ik}^2 = 2\|x_{:,k}\|^2 \tag{11.136}$$
$$c_k = 2\sum_{i=1}^n x_{ik}(y_i - w_{-k}^T x_{i,-k}) = 2 x_{:,k}^T r_k \tag{11.137}$$

where $w_{-k}$ is w without component k, $x_{i,-k}$ is $x_i$ without component k, and $r_k = y - w_{-k}^T x_{:,-k}$ is the
residual due to using all the features except feature k. Hint: Partition the weights into those involving k
and those not involving k.


b. Show that if $\frac{\partial}{\partial w_k}\mathrm{RSS}(w) = 0$, then

$$\hat{w}_k = \frac{x_{:,k}^T r_k}{\|x_{:,k}\|^2} \tag{11.138}$$

Hence when we sequentially add features, the optimal weight for feature k is computed by
orthogonally projecting $x_{:,k}$ onto the current residual.


Exercise 11.4 [Reducing elastic net to lasso] 


Define 

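$$J_1(w) = \|y - Xw\|^2 + \lambda_2\|w\|_2^2 + \lambda_1\|w\|_1 \tag{11.139}$$

and

$$J_2(w) = \|\tilde{y} - \tilde{X}w\|^2 + c\lambda_1\|w\|_1 \tag{11.140}$$

where $\|w\|^2 = \|w\|_2^2 = \sum_i w_i^2$ is the squared 2-norm, $\|w\|_1 = \sum_i |w_i|$ is the 1-norm, $c = (1 + \lambda_2)^{-\frac{1}{2}}$, and

$$\tilde{X} = c\begin{pmatrix} X \\ \sqrt{\lambda_2}\, I_d \end{pmatrix}, \quad \tilde{y} = \begin{pmatrix} y \\ 0_{d \times 1} \end{pmatrix} \tag{11.141}$$

Show

$$\operatorname{argmin} J_1(w) = c\,(\operatorname{argmin} J_2(w)) \tag{11.142}$$

i.e.

$$J_1(cw) = J_2(w) \tag{11.143}$$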


and hence that one can solve an elastic net problem using a lasso solver on modified data.


Exercise 11.5 [Shrinkage in linear regression *]
(Source: Jaakkola.) Consider performing linear regression with an orthonormal design matrix, so $\|x_{:,k}\|_2^2 = 1$
for each column (feature) k, and $x_{:,k}^T x_{:,j} = 0$, so we can estimate each parameter $w_k$ separately.

Figure 10.15b plots $\hat{w}_k$ vs $c_k = 2y^T x_{:,k}$, the correlation of feature k with the response, for 3 different
estimation methods: ordinary least squares (OLS), ridge regression with parameter $\lambda_2$, and lasso with
parameter $\lambda_1$.


a. Unfortunately we forgot to label the plots. Which method does the solid (1), dotted (2) and dashed (3)
line correspond to?


b. What is the value of $\lambda_1$?
c. What is the value of $\lambda_2$?


Exercise 11.6 [EM for mixture of linear regression experts] 


Derive the EM equations for fitting a mixture of linear regression experts.


12.1 Introduction


In Chapter 10, we discussed logistic regression, which, in the binary case, corresponds to the model
$p(y|x, w) = \mathrm{Ber}(y|\sigma(w^T x))$. In Chapter 11, we discussed linear regression, which corresponds to the
model $p(y|x, w) = \mathcal{N}(y|w^T x, \sigma^2)$. These are obviously very similar to each other. In particular, the
mean of the output, $\mathbb{E}[y|x, w]$, depends on the inputs x only through the linear combination $w^T x$ in both cases.

It turns out that there is a broad family of models with this property, known as generalized 
linear models or GLMs [MN89]. 

A GLM is a conditional version of an exponential family distribution (Section 3.4), in which the 
natural parameters are a linear function of the input. More precisely, the model has the following 
form: 


$$p(y_n|x_n, w, \sigma^2) = \exp\left[\frac{y_n\eta_n - A(\eta_n)}{\sigma^2} + \log h(y_n, \sigma^2)\right] \tag{12.1}$$

where $\eta_n \triangleq w^T x_n$ is the (input dependent) natural parameter, $A(\eta_n)$ is the log normalizer, $\mathcal{T}(y) = y$
is the sufficient statistic, and $\sigma^2$ is the dispersion term.
We will denote the mapping from the linear inputs to the mean of the output using $\mu_n = \ell^{-1}(\eta_n)$,
where the function $\ell$ is known as the link function, and $\ell^{-1}$ is known as the mean function.
Based on the results in Section 3.4.3, we can show that the mean and variance of the response
variable are as follows:

$$\mathbb{E}[y_n|x_n, w, \sigma^2] = A'(\eta_n) \triangleq \ell^{-1}(\eta_n) \tag{12.2}$$
$$\mathbb{V}[y_n|x_n, w, \sigma^2] = A''(\eta_n)\,\sigma^2 \tag{12.3}$$


12.2 Examples 


In this section, we give some examples of widely used GLMs. 


1. Technically speaking, GLMs use a slight extension of the natural exponential family known as the exponential
dispersion family. For a scalar variable, this has the form $p(y|\eta, \sigma^2) = h(y, \sigma^2)\exp\left[\frac{\eta y - A(\eta)}{\sigma^2}\right]$. Here $\sigma^2$ is called
the dispersion parameter. For fixed $\sigma^2$, this is a natural exponential family.


12.2.1 Linear regression


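Recall that linear regression has the form

$$p(y_n|x_n, w, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\left(-\frac{1}{2\sigma^2}(y_n - w^T x_n)^2\right) \tag{12.4}$$

Hence

$$\log p(y_n|x_n, w, \sigma^2) = -\frac{1}{2\sigma^2}(y_n - \eta_n)^2 - \frac{1}{2}\log(2\pi\sigma^2) \tag{12.5}$$

where $\eta_n = w^T x_n$. We can write this in GLM form as follows:

$$\log p(y_n|x_n, w, \sigma^2) = \frac{y_n\eta_n - \frac{\eta_n^2}{2}}{\sigma^2} - \frac{1}{2}\left(\frac{y_n^2}{\sigma^2} + \log(2\pi\sigma^2)\right) \tag{12.6}$$

We see that $A(\eta_n) = \eta_n^2/2$ and hence

$$\mathbb{E}[y_n] = \eta_n = w^T x_n \tag{12.7}$$
$$\mathbb{V}[y_n] = \sigma^2 \tag{12.8}$$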

12.2.2 Binomial regression 


If the response variable is the number of successes in $N_n$ trials, $y_n \in \{0, \ldots, N_n\}$, we can use
binomial regression, which is defined by

$$p(y_n|x_n, N_n, w) = \mathrm{Bin}(y_n|\sigma(w^T x_n), N_n) \tag{12.9}$$

We see that binary logistic regression is the special case when $N_n = 1$.
The log pmf is given by

$$\log p(y_n|x_n, N_n, w) = y_n\log\mu_n + (N_n - y_n)\log(1 - \mu_n) + \log\binom{N_n}{y_n} \tag{12.10}$$
$$= y_n\log\left(\frac{\mu_n}{1 - \mu_n}\right) + N_n\log(1 - \mu_n) + \log\binom{N_n}{y_n} \tag{12.11}$$

where $\mu_n = \sigma(\eta_n)$. To rewrite this in GLM form, let us define

$$\eta_n \triangleq \log\left[\frac{\mu_n}{1 - \mu_n}\right] = \log\left[\frac{1}{1 + e^{-w^T x_n}} \cdot \frac{1 + e^{-w^T x_n}}{e^{-w^T x_n}}\right] = \log\frac{1}{e^{-w^T x_n}} = w^T x_n \tag{12.12}$$

Hence we can write binomial regression in GLM form as follows

$$\log p(y_n|x_n, N_n, w) = y_n\eta_n - A(\eta_n) + h(y_n) \tag{12.13}$$

where $h(y_n) = \log\binom{N_n}{y_n}$ and

$$A(\eta_n) = -N_n\log(1 - \mu_n) = N_n\log(1 + e^{\eta_n}) \tag{12.14}$$

Hence

$$\mathbb{E}[y_n] = \frac{dA}{d\eta_n} = \frac{N_n e^{\eta_n}}{1 + e^{\eta_n}} = \frac{N_n}{1 + e^{-\eta_n}} = N_n\mu_n \tag{12.15}$$

and

$$\mathbb{V}[y_n] = \frac{d^2 A}{d\eta_n^2} = N_n\mu_n(1 - \mu_n) \tag{12.16}$$


12.2.3 Poisson regression 


If the response variable is an integer count, $y_n \in \{0, 1, \ldots\}$, we can use Poisson regression, which
is defined by

$$p(y_n|x_n, w) = \mathrm{Poi}(y_n|\exp(w^T x_n)) \tag{12.17}$$

where

$$\mathrm{Poi}(y|\mu) = e^{-\mu}\frac{\mu^y}{y!} \tag{12.18}$$

is the Poisson distribution. Poisson regression is widely used in bio-statistical applications, where
$y_n$ might represent the number of diseases of a given person or place, or the number of reads at a
genomic location in a high-throughput sequencing context (see e.g., [Kua+09]).

The log pmf is given by

$$\log p(y_n|x_n, w) = y_n\log\mu_n - \mu_n - \log(y_n!) \tag{12.19}$$

where $\mu_n = \exp(w^T x_n)$. Hence in GLM form we have

$$\log p(y_n|x_n, w) = y_n\eta_n - A(\eta_n) + h(y_n) \tag{12.20}$$

where $\eta_n = \log(\mu_n) = w^T x_n$, $A(\eta_n) = \mu_n = e^{\eta_n}$, and $h(y_n) = -\log(y_n!)$. Hence

$$\mathbb{E}[y_n] = \frac{dA}{d\eta_n} = e^{\eta_n} = \mu_n \tag{12.21}$$

and

$$\mathbb{V}[y_n] = \frac{d^2 A}{d\eta_n^2} = e^{\eta_n} = \mu_n \tag{12.22}$$
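
The relationship between the log normalizer and the moments can be checked numerically. The sketch below differentiates $A(\eta)$ by central finite differences for the binomial and Poisson cases and compares the results with the closed-form mean and variance derived above. The function names, the test value of `eta`, and the step size `eps` are assumptions for illustration.

```python
import numpy as np

def A_binomial(eta, n_trials):
    # Log normalizer of binomial regression: N * log(1 + e^eta).
    return n_trials * np.log1p(np.exp(eta))

def A_poisson(eta):
    # Log normalizer of Poisson regression: e^eta.
    return np.exp(eta)

def num_derivs(A, eta, eps=1e-4):
    """First and second derivatives of A at eta via central differences."""
    d1 = (A(eta + eps) - A(eta - eps)) / (2 * eps)
    d2 = (A(eta + eps) - 2 * A(eta) + A(eta - eps)) / eps**2
    return d1, d2

eta = 0.7
n_trials = 12
mu = 1.0 / (1.0 + np.exp(-eta))   # sigmoid(eta)

bin_mean, bin_var = num_derivs(lambda e: A_binomial(e, n_trials), eta)
poi_mean, poi_var = num_derivs(A_poisson, eta)

print("binomial:", bin_mean, "vs", n_trials * mu, "|", bin_var, "vs", n_trials * mu * (1 - mu))
print("poisson: ", poi_mean, "vs", np.exp(eta), "|", poi_var, "vs", np.exp(eta))
```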


12.3 GLMs with non-canonical link functions 


We have seen how the mean parameters of the output distribution are given by $\mu = \ell^{-1}(\eta)$, where
the function $\ell$ is the link function. There are several choices for this function, as we now discuss.

The canonical link function $\ell$ satisfies the property that $\theta = \ell(\mu)$, where $\theta$ are the canonical
(natural) parameters. Hence

$$\theta = \ell(\mu) = \ell(\ell^{-1}(\eta)) = \eta \tag{12.23}$$

This is what we have assumed so far. For example, for the Bernoulli distribution, the canonical
parameter is the log-odds $\theta = \log(\mu/(1 - \mu))$, which is given by the logit transform

$$\theta = \ell(\mu) = \mathrm{logit}(\mu) = \log\left(\frac{\mu}{1 - \mu}\right) \tag{12.24}$$

The inverse of this is the sigmoid or logistic function $\mu = \sigma(\theta) = 1/(1 + e^{-\theta})$.
However, we are free to use other kinds of link function. For example, the probit link function
has the form

$$\eta = \ell(\mu) = \Phi^{-1}(\mu) \tag{12.25}$$

Another link function that is sometimes used for binary responses is the complementary log-log
function

$$\eta = \ell(\mu) = \log(-\log(1 - \mu)) \tag{12.26}$$

This is used in applications where we either observe 0 events (denoted by y = 0) or one or more
(denoted by y = 1), where events are assumed to be governed by a Poisson distribution with rate $\lambda$.
Let E be the number of events. The Poisson assumption means $p(E = 0) = \exp(-\lambda)$ and hence

$$p(y = 0) = (1 - \mu) = p(E = 0) = \exp(-\lambda) \tag{12.27}$$

Thus $\lambda = -\log(1 - \mu)$. When $\lambda$ is a function of covariates, we need to ensure it is positive, so we use
$\lambda = e^{\eta}$, and hence

$$\eta = \log(\lambda) = \log(-\log(1 - \mu)) \tag{12.28}$$
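
The three link functions and their inverses (the mean functions) are easy to write down directly. The following sketch uses only the Python standard library; the function names and the test value of `mu` are assumptions for illustration.

```python
import math
from statistics import NormalDist

# Link functions: map a mean mu in (0, 1) to the linear predictor eta.
def logit(mu):
    return math.log(mu / (1.0 - mu))

def probit(mu):
    return NormalDist().inv_cdf(mu)        # inverse standard normal cdf

def cloglog(mu):
    return math.log(-math.log(1.0 - mu))

# Mean functions: the inverses, mapping eta back to mu.
def inv_logit(eta):
    return 1.0 / (1.0 + math.exp(-eta))

def inv_probit(eta):
    return NormalDist().cdf(eta)

def inv_cloglog(eta):
    # With Poisson rate lambda = e^eta, p(y = 1) = 1 - exp(-lambda).
    return 1.0 - math.exp(-math.exp(eta))

links = {
    "logit": (logit, inv_logit),
    "probit": (probit, inv_probit),
    "cloglog": (cloglog, inv_cloglog),
}

mu = 0.3
etas = {name: link(mu) for name, (link, _) in links.items()}
recovered = {name: inv(etas[name]) for name, (_, inv) in links.items()}
for name in links:
    print(name, "eta =", etas[name], "recovered mu =", recovered[name])
```

All three links map the same mean to different values of the linear predictor, so the fitted weights are not directly comparable across link functions.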


12.4 Maximum likelihood estimation 


GLMs can be fit using similar methods to those that we used to fit logistic regression. In particular, 
the negative log-likelihood has the following form (ignoring constant terms): 


$$\mathrm{NLL}(w) = -\log p(\mathcal{D}|w) = -\frac{1}{\sigma^2}\sum_{n=1}^N \ell_n \tag{12.29}$$

where

$$\ell_n \triangleq \eta_n y_n - A(\eta_n) \tag{12.30}$$

where $\eta_n = w^T x_n$. For notational simplicity, we will assume $\sigma^2 = 1$.


We can compute the gradient for a single term as follows:

$$g_n \triangleq \frac{\partial \ell_n}{\partial w} = \frac{\partial \ell_n}{\partial \eta_n}\frac{\partial \eta_n}{\partial w} = (y_n - A'(\eta_n))\,x_n = (y_n - \mu_n)\,x_n \tag{12.31}$$

where $\mu_n = f(w^T x_n)$, and f is the inverse link function that maps from canonical parameters to
mean parameters. For example, in the case of logistic regression, $f(\eta_n) = \sigma(\eta_n)$, so we recover

[Figure 12.1 panels: Data, DummyRegressor, Ridge, PoissonRegressor. Horizontal axes: y (observed Frequency) for the data panel; y_pred (predicted expected Frequency) for the three models.]


Figure 12.1: Predictions of insurance claim rates on the test set. (a) Data. (b) Constant predictor. (c) Linear
regression. (d) Poisson regression. Generated by poisson_regression_insurance.ipynb.


Equation (10.21). This gradient expression can be used inside SGD, or some other gradient method, 
in the obvious way. 
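The Hessian is given by

$$H = \frac{\partial^2}{\partial w\,\partial w^T}\mathrm{NLL}(w) = -\sum_{n=1}^N \frac{\partial g_n}{\partial w^T} \tag{12.32}$$

where

$$\frac{\partial g_n}{\partial w^T} = \frac{\partial g_n}{\partial \mu_n}\frac{\partial \mu_n}{\partial w^T} = -x_n f'(w^T x_n)\,x_n^T \tag{12.33}$$

Hence

$$H = \sum_{n=1}^N f'(\eta_n)\,x_n x_n^T \tag{12.34}$$

For example, in the case of logistic regression, $f(\eta_n) = \sigma(\eta_n)$, and $f'(\eta_n) = \sigma(\eta_n)(1 - \sigma(\eta_n))$, so we
recover Equation (10.23). In general, we see that the Hessian is positive definite, since $f'(\eta_n) > 0$;
hence the negative log likelihood is convex, so the MLE for a GLM is unique (assuming $f'(\eta_n) > 0$
for all n).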

Based on the above results, we can fit GLMs using gradient based solvers in a manner that is very 
similar to how we fit logistic regression models. 
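
As a concrete instance, the sketch below fits a Poisson regression model with Newton's method, using the gradient and Hessian expressions above with $f(\eta) = e^{\eta}$, so that $f'(\eta_n) = \mu_n$. The function name `fit_poisson_newton`, the fixed number of iterations, the zero initialization, and the synthetic data are assumptions for illustration; a production solver would add a line search and a convergence test.

```python
import numpy as np

def fit_poisson_newton(X, y, n_iter=25):
    """MLE for Poisson regression with the canonical (log) link."""
    w = np.zeros(X.shape[1])
    for _ in range(n_iter):
        mu = np.exp(X @ w)              # mean function f(eta) = exp(eta)
        grad = -X.T @ (y - mu)          # gradient of the NLL
        H = X.T @ (mu[:, None] * X)     # Hessian: sum_n mu_n x_n x_n^T
        w = w - np.linalg.solve(H, grad)
    return w

rng = np.random.default_rng(5)
N = 2000
X = np.column_stack([np.ones(N), rng.normal(size=(N, 2))])
w_true = np.array([0.5, 0.3, -0.2])     # assumed ground truth
y = rng.poisson(np.exp(X @ w_true))

w_hat = fit_poisson_newton(X, y)
grad_norm = np.linalg.norm(X.T @ (y - np.exp(X @ w_hat)))
print("estimate:", w_hat.round(3), " gradient norm at solution:", grad_norm)
```

Replacing the Newton step with a small step along `-grad` gives plain gradient descent, and evaluating `grad` on a minibatch gives SGD.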


12.5 Worked example: predicting insurance claims 


In this section, we give an example of predicting insurance claims using linear and Poisson regression.²
The goal is to predict the expected number of insurance claims per year following car accidents. The 
dataset consists of 678k examples with 9 features, such as driver age, vehicle age, vehicle power, 


2. This example is from https://scikit-learn.org/stable/auto_examples/linear_model/plot_poisson_regression_non_normal_loss.html

Name     MSE    MAE    Deviance
Dummy    0.564  0.189  0.625
Ridge    0.560  0.177  0.601
Poisson  0.560  0.186  0.594

Table 12.1: Performance metrics on the test set. MSE = mean squared error. MAE = mean absolute error. 
Deviance = Poisson deviance. 


etc. The target is the frequency of claims, which is the number of claims per policy divided by the 
exposure (i.e., the duration of the policy in years). 

We plot the test set in Figure 12.1(a). We see that for 94% of the policies, no claims are made, so 
the data has lots of 0s, as is typical for count and rate data. The average frequency of claims is 10%. 
This can be converted into a dummy model, which always predicts this constant. This results in the 
predictions shown in Figure 12.1(b). The goal is to do better than this. 

A simple approach is to use linear regression, combined with some simple feature engineering 
(binning the continuous values, and one-hot encoding the categoricals). (We use a small amount of $\ell_2$
regularization, so technically this is ridge regression.) This gives the results shown in Figure 12.1(c). 
This is better than the baseline, but still not very good. In particular, it can predict negative 
outcomes, and fails to capture the long tail. 

We can do better using Poisson regression, using the same features but a log link function. The 
results are shown in Figure 12.1(d). We see that predictions are much better. 

An interesting question is how to quantify performance in this kind of problem. If we use mean 
squared error, or mean absolute error, we may conclude from Table 12.1 that ridge regression is 
better than Poisson regression, but this is clearly not true, as shown in Figure 12.1. Instead it is 
more common to measure performance using the deviance, which is defined as 


$$D(y, \hat{\mu}) \triangleq 2\sum_i \left(\log p(y_i|\mu_i^*) - \log p(y_i|\hat{\mu}_i)\right) \tag{12.35}$$

where $\hat{\mu}_i$ is the predicted parameter for the i'th example (based on the input features $x_i$ and the
training set $\mathcal{D}$), and $\mu_i^*$ is the optimal parameter estimated by fitting the model just to the true
output $y_i$. (This is the so-called saturated model, that perfectly fits the test set.) In the case of
Poisson regression, we have $\mu_i^* = y_i$. Hence

$$D(y, \hat{\mu}) = 2\sum_i \left[(y_i\log y_i - y_i - \log(y_i!)) - (y_i\log\hat{\mu}_i - \hat{\mu}_i - \log(y_i!))\right] \tag{12.36}$$
$$= 2\sum_i \left[y_i\log\frac{y_i}{\hat{\mu}_i} + \hat{\mu}_i - y_i\right] \tag{12.37}$$


By this metric, the Poisson model is clearly better (see last column of Table 12.1). 
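
The final expression for the Poisson deviance translates directly into code. The sketch below adopts the usual convention that $0 \log 0 = 0$ for examples with zero counts; the function name and the toy inputs are assumptions for illustration. Dividing the result by the number of examples gives a mean deviance, which is one common way such a metric is reported.

```python
import numpy as np

def poisson_deviance(y, mu_hat):
    """Poisson deviance 2 * sum_i [y_i log(y_i / mu_i) + mu_i - y_i].

    Assumes every predicted mean mu_hat is strictly positive."""
    y = np.asarray(y, dtype=float)
    mu_hat = np.asarray(mu_hat, dtype=float)
    # For y_i = 0 the term y_i * log(y_i / mu_i) is defined to be 0.
    ratio = np.where(y > 0, y / mu_hat, 1.0)
    return 2.0 * np.sum(y * np.log(ratio) + mu_hat - y)

y_obs = np.array([0, 0, 1, 0, 3])
constant_pred = np.full(5, y_obs.mean())          # dummy predictor
better_pred = np.array([0.1, 0.2, 0.9, 0.1, 2.5])  # hypothetical predictions

dev_constant = poisson_deviance(y_obs, constant_pred)
dev_better = poisson_deviance(y_obs, better_pred)
print("deviance of constant predictor:", dev_constant)
print("deviance of better predictor:  ", dev_better)
```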

We can also compute a calibration plot, which plots the actual frequency vs the predicted 
frequency. To compute this, we bin the predictions into intervals, and then count the empirical 
frequency of claims for all examples whose predicted frequency falls into that bin. The results 
are shown in Figure 12.2. We see that the constant baseline is well calibrated, but of course it is 
not very accurate. The ridge model is miscalibrated in the low frequency regime. In particular, it

[Figure 12.2 panels: DummyRegressor(), Ridge(alpha=1e-06), PoissonRegressor(alpha=1e-12, max_iter=300). Each panel shows predictions and observations, with axes labeled "Fraction of samples sorted by y_pred" and "Mean Frequency (y_pred)".]


Figure 12.2: Calibration plot for insurance claims prediction. Generated by
poisson_regression_insurance.ipynb.


underestimates the total number of claims in the test set to be 10,693, whereas the truth is 11,935. 
The Poisson model is better calibrated (i.e., when it predicts examples will have a high claim rate, 
they do in fact have a high claim rate), and it predicts the total number of claims to be 11,930. 
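
The binning procedure behind a calibration plot can be sketched as follows. Here we sort the examples by predicted frequency and split them into equal-sized groups, which is one reasonable choice of bins (fixed-width intervals are another). The function name, the number of bins, and the synthetic data are assumptions for illustration.

```python
import numpy as np

def calibration_curve(y_true, y_pred, n_bins=10):
    """Mean prediction and mean observed outcome within each bin of
    examples, where bins are equal-sized groups sorted by y_pred."""
    order = np.argsort(y_pred)
    bins = np.array_split(order, n_bins)
    pred_mean = np.array([y_pred[b].mean() for b in bins])
    obs_mean = np.array([y_true[b].mean() for b in bins])
    return pred_mean, obs_mean

# Synthetic check: a predictor that knows the true rate is well calibrated.
rng = np.random.default_rng(8)
n = 20000
rate = np.exp(rng.normal(-1.0, 0.5, size=n))   # hypothetical true claim rates
y = rng.poisson(rate)

pred_mean, obs_mean = calibration_curve(y, rate)
for p, o in zip(pred_mean, obs_mean):
    print(f"predicted {p:.3f}   observed {o:.3f}")
```

A well calibrated model gives points close to the diagonal when `obs_mean` is plotted against `pred_mean`.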
PART III


Deep Neural Networks


13 Neural Networks for Tabular Data


13.1 Introduction 


In Part II, we discussed linear models for regression and classification. In particular, in Chapter 10,
we discussed logistic regression, which, in the binary case, corresponds to the model
$p(y|x, w) = \mathrm{Ber}(y|\sigma(w^T x))$, and in the multiclass case corresponds to the model $p(y|x, W) =
\mathrm{Cat}(y|\mathrm{softmax}(Wx))$. In Chapter 11, we discussed linear regression, which corresponds to the
model $p(y|x, w) = \mathcal{N}(y|w^T x, \sigma^2)$. And in Chapter 12, we discussed generalized linear models, which
generalize these models to other kinds of output distributions, such as Poisson. However, all these
models make the strong assumption that the input-output mapping is linear.

A simple way of increasing the flexibility of such models is to perform a feature transformation, by
replacing $x$ with $\phi(x)$. For example, we can use a polynomial transform, which in 1d is given by
$\phi(x) = [1, x, x^2, x^3, \ldots]$, as we discussed in Section 1.2.2.2. This is sometimes called basis function
expansion. The model now becomes

$$f(x; \theta) = W\phi(x) + b \tag{13.1}$$


This is still linear in the parameters $\theta = (W, b)$, which makes model fitting easy (since the negative
log-likelihood is convex). However, having to specify the feature transformation by hand is very 
limiting. 

A natural extension is to endow the feature extractor with its own parameters, $\theta_2$, to get

$$f(x; \theta) = W\phi(x; \theta_2) + b \tag{13.2}$$

where $\theta = (\theta_1, \theta_2)$ and $\theta_1 = (W, b)$. We can obviously repeat this process recursively, to create more
and more complex functions. If we compose L functions, we get

$$f(x; \theta) = f_L(f_{L-1}(\cdots(f_1(x))\cdots)) \tag{13.3}$$

where $f_\ell(x) = f(x; \theta_\ell)$ is the function at layer $\ell$. This is the key idea behind deep neural networks
or DNNs. 

The term “DNN” actually encompasses a larger family of models, in which we compose differentiable 
functions into any kind of DAG (directed acyclic graph), mapping input to output. Equation (13.3) is 
the simplest example where the DAG is a chain. This is known as a feedforward neural network 
(FFNN) or multilayer perceptron (MLP). 
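
The chain structure can be written almost verbatim in code, as in the sketch below: each layer is a function, and the network is their composition. The layer sizes, the random weights, and the choice of a ReLU nonlinearity for the hidden layers are assumptions for illustration (activation functions are not discussed at this point in the text).

```python
import numpy as np
from functools import reduce

def make_layer(W, b, activation):
    """Return the function f_l(x) = activation(W x + b)."""
    return lambda x: activation(W @ x + b)

def compose(layers):
    """Return f(x) = f_L(f_{L-1}(...f_1(x)...))."""
    return lambda x: reduce(lambda h, f: f(h), layers, x)

relu = lambda a: np.maximum(a, 0.0)   # assumed hidden nonlinearity
identity = lambda a: a                # linear output layer

rng = np.random.default_rng(6)
sizes = [4, 8, 8, 2]                  # input dim, two hidden layers, output dim
layers = []
for idx, (d_in, d_out) in enumerate(zip(sizes[:-1], sizes[1:])):
    act = identity if idx == len(sizes) - 2 else relu
    layers.append(make_layer(rng.normal(size=(d_out, d_in)), np.zeros(d_out), act))

f = compose(layers)
x = rng.normal(size=4)
out = f(x)
print("output:", out)
```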

An MLP assumes that the input is a fixed-dimensional vector, say $x \in \mathbb{R}^D$. It is common to
call such data “structured data” or “tabular data”, since the data is often stored in an $N \times D$

x1  x2  y
0   0   0
0   1   1
1   0   1
1   1   0

Table 13.1: Truth table for the XOR (exclusive OR) function, $y = x_1 \veebar x_2$.


design matrix, where each column (feature) has a specific meaning, such as height, weight, age, 
etc. In later chapters, we discuss other kinds of DNNs that are more suited to “unstructured 
data” such as images and text, where the input data is variable sized, and each individual element 
(e.g., pixel or word) is often meaningless on its own. In particular, in Chapter 14, we discuss 
convolutional neural networks (CNN), which are designed to work with images; in Chapter 15, 
we discuss recurrent neural networks (RNN) and transformers, which are designed to work 
with sequences; and in Chapter 23, we discuss graph neural networks (GNN), which are designed 
to work with graphs. 

Although DNNs can work well, there are often a lot of engineering details that need to be addressed 
to get good performance. Some of these details are discussed in the supplementary material to this 
book, available at probml.ai. There are also various other books that cover this topic in more depth 
(e.g., [Zha+20; Cho21; Gér19; GBC16; Raf22]), as well as a multitude of online courses. For a more
theoretical treatment, see e.g., [Ber+21; Cal20; Aro+21; RY21]. 


13.2 Multilayer perceptrons (MLPs) 


In Section 10.2.5, we explained that a perceptron is a deterministic version of logistic regression. 
Specifically, it is a mapping of the following form: 


$$f(x; \theta) = \mathbb{I}(w^T x + b > 0) = H(w^T x + b) \tag{13.4}$$


where $H(a)$ is the Heaviside step function, also known as a linear threshold function. Since
the decision boundaries represented by perceptrons are linear, they are very limited in what they can 
represent. In 1969, Marvin Minsky and Seymour Papert published a famous book called Perceptrons 
[MP69] in which they gave numerous examples of pattern recognition problems which perceptrons 
cannot solve. We give a specific example below, before discussing how to solve the problem.


Figure 13.1: (a) Illustration of the fact that the XOR function is not linearly separable, but can be separated by
the two layer model using Heaviside activation functions. Adapted from Figure 10.6 of [Gér19]. Generated by
xor_heaviside.ipynb. (b) A neural net with one hidden layer, whose weights have been manually constructed
to implement the XOR function. $h_1$ is the AND function and $h_2$ is the OR function. The bias terms are
implemented using weights from constant nodes with the value 1.


13.2.1 The XOR problem


One of the most famous examples from the Perceptrons book is the XOR problem. Here the goal
is to learn a function that computes the exclusive OR of its two binary inputs. The truth table for 
this function is given in Table 13.1. We visualize this function in Figure 13.1a. It is clear that the 
data is not linearly separable, so a perceptron cannot represent this mapping. 

However, we can overcome this problem by stacking multiple perceptrons on top of each other. 
This is called a multilayer perceptron (MLP). For example, to solve the XOR problem, we can 
use the MLP shown in Figure 13.1b. This consists of 3 perceptrons, denoted $h_1$, $h_2$ and $y$. The nodes
marked $x$ are inputs, and the nodes marked 1 are constant terms. The nodes $h_1$ and $h_2$ are called
hidden units, since their values are not observed in the training data.

The first hidden unit computes $h_1 = x_1 \wedge x_2$ by using appropriately set weights. (Here $\wedge$ is the
AND operation.) In particular, it has inputs from $x_1$ and $x_2$, both weighted by 1.0, but has a bias
term of -1.5 (this is implemented by a “wire” with weight -1.5 coming from a dummy node whose
value is fixed to 1). Thus $h_1$ will fire iff $x_1$ and $x_2$ are both on, since then


$$\mathbf{w}_1^\top \mathbf{x} - b_1 = [1.0, 1.0]^\top [1, 1] - 1.5 = 0.5 > 0 \tag{13.5}$$


1. The term “unstructured data” is a bit misleading, since images and text do have structure. For example, neighboring 
pixels in an image are highly correlated, as are neighboring words in a sentence. Indeed, it is precisely this structure 
that is exploited (assumed) by CNNs and RNNs. By contrast, MLPs make no assumptions about their inputs. This is 
useful for applications such as tabular data, where the structure (dependencies between the columns) is usually not 
obvious, and thus needs to be learned. We can also apply MLPs to images and text, as we will see, but performance 
will usually be worse compared to specialized models, such as CNNs and RNNs. (There are some exceptions, such
as the MLP-mixer model of [Tol+21], which is an unstructured model that can learn to perform well on image and
text data, but such models need massive datasets to overcome their lack of inductive bias.) 
Similarly, the second hidden unit computes $h_2 = x_1 \vee x_2$, where $\vee$ is the OR operation, and the third
computes the output $y = \overline{h_1} \wedge h_2$, where $\overline{h} = \neg h$ is the NOT (logical negation) operation. Thus $y$
computes


$$y = f(x_1, x_2) = \overline{(x_1 \wedge x_2)} \wedge (x_1 \vee x_2) \tag{13.6}$$


This is equivalent to the XOR function. 
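
As a quick check of this construction, the following minimal sketch evaluates the three perceptrons on all four binary inputs. Only the weights and bias of $h_1$ are stated in the text; the values used for $h_2$ and $y$ are assumptions chosen to realize OR and "(NOT $h_1$) AND $h_2$", and the function names are illustrative.

```python
def heaviside(a):
    """Linear threshold unit: 1 if a > 0, else 0."""
    return 1 if a > 0 else 0

def xor_mlp(x1, x2):
    # h1 = AND: weights (1, 1) and bias -1.5, as given in the text.
    h1 = heaviside(1.0 * x1 + 1.0 * x2 - 1.5)
    # h2 = OR: weights (1, 1) and bias -0.5 (assumed values).
    h2 = heaviside(1.0 * x1 + 1.0 * x2 - 0.5)
    # y = (NOT h1) AND h2: weights (-1, 1) and bias -0.5 (assumed values).
    return heaviside(-1.0 * h1 + 1.0 * h2 - 0.5)

truth_table = {(x1, x2): xor_mlp(x1, x2) for x1 in (0, 1) for x2 in (0, 1)}
print(truth_table)
```

The output unit fires only when exactly one input is on, which matches the XOR truth table.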

By generalizing this example, we can show that an MLP can represent any logical function. 
However, we obviously want to avoid having to specify the weights and biases by hand. In the rest of 
this chapter, we discuss ways to learn these parameters from data. 


13.2.2 Differentiable MLPs 


The MLP we discussed in Section 13.2.1 was defined as a stack of perceptrons, each of which involved 
the non-differentiable Heaviside function. This makes such models difficult to train, which is why 
they were never widely used. However, suppose we replace the Heaviside function $H : \mathbb{R} \to \{0,1\}$
with a differentiable activation function $\varphi : \mathbb{R} \to \mathbb{R}$. More precisely, we define the hidden units
$\mathbf{z}_l$ at each layer $l$ to be a linear transformation of the hidden units at the previous layer passed
elementwise through this activation function:


$$\mathbf{z}_l = f_l(\mathbf{z}_{l-1}) = \varphi_l(\mathbf{b}_l + \mathbf{W}_l \mathbf{z}_{l-1}) \tag{13.7}$$

or, in scalar form,

$$z_{kl} = \varphi_l\left(b_{kl} + \sum_{j=1}^{K_{l-1}} w_{lkj} z_{j,l-1}\right) \tag{13.8}$$

The quantity that is passed to the activation function is called the pre-activations:

$$\mathbf{a}_l = \mathbf{b}_l + \mathbf{W}_l \mathbf{z}_{l-1} \tag{13.9}$$

so $\mathbf{z}_l = \varphi_l(\mathbf{a}_l)$.

If we now compose L of these functions together, as in Equation (13.3), then we can compute 
the gradient of the output wrt the parameters in each layer using the chain rule, also known as 
backpropagation, as we explain in Section 13.3. (This is true for any kind of differentiable activation 
function, although some kinds work better than others, as we discuss in Section 13.2.3.) We can 
then pass the gradient to an optimizer, and thus minimize some training objective, as we discuss in 
Section 13.4. For this reason, the term “MLP” almost always refers to this differentiable form of the 
model, rather than the historical version with non-differentiable linear threshold units. 
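
To make the layer recursion concrete, here is a minimal forward-pass sketch in plain Python. It computes the pre-activations $\mathbf{a}_l = \mathbf{b}_l + \mathbf{W}_l \mathbf{z}_{l-1}$ and then applies the activation elementwise. The helper names, the choice of tanh, and the weight values are all assumptions made for illustration.

```python
import math

def matvec(W, z):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(w * zj for w, zj in zip(row, z)) for row in W]

def layer(W, b, z_prev, phi):
    """One layer: pre-activations a = b + W z_prev, then z = phi(a) elementwise."""
    a = [bk + s for bk, s in zip(b, matvec(W, z_prev))]
    return [phi(ak) for ak in a]

def mlp_forward(params, x, phi=math.tanh):
    """Compose the layers; params is a list of (W, b) pairs, one per layer."""
    z = x
    for W, b in params:
        z = layer(W, b, z, phi)
    return z

# Hypothetical weights for a 2 -> 3 -> 1 network (illustrative values only).
params = [
    ([[0.5, -0.2], [0.1, 0.3], [-0.4, 0.8]], [0.0, 0.1, -0.1]),
    ([[0.2, -0.5, 0.7]], [0.05]),
]
out = mlp_forward(params, [1.0, 2.0])
print(out)
```

Because every step is a differentiable function of the weights and biases, the same structure can be differentiated layer by layer.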


13.2.3 Activation functions 


We are free to use any kind of differentiable activation function we like at each layer. However, if we 
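use a linear activation function, $\varphi_\ell(a) = c_\ell a$, then the whole model reduces to a regular linear model.
To see this, note that Equation (13.3) becomes


$$f(\mathbf{x}; \boldsymbol{\theta}) = \mathbf{W}_L c_L (\mathbf{W}_{L-1} c_{L-1} (\cdots (\mathbf{W}_1 \mathbf{x}) \cdots)) \propto \mathbf{W}_L \mathbf{W}_{L-1} \cdots \mathbf{W}_1 \mathbf{x} = \mathbf{W}' \mathbf{x} \tag{13.10}$$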
Figure 13.2: (a) Illustration of how the sigmoid function is linear for inputs near 0, but saturates for large
positive and negative inputs. Adapted from Figure 11.1 of [Gér19]. (b) Plots of some neural network activation
functions (sigmoid, tanh and ReLU). Generated by activation_fun_plot.ipynb.


where we dropped the bias terms for notational simplicity. For this reason, it is important to use 
nonlinear activation functions. 
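
The collapse in Equation (13.10) is easy to verify numerically. The sketch below uses two bias-free layers with linear activations and compares the stacked model against a single matrix; here the constants $c_1 c_2$ are absorbed into $\mathbf{W}'$, so the two outputs agree exactly. All values and names are illustrative assumptions.

```python
def matvec(A, x):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]

def matmul(A, B):
    """Matrix-matrix product for matrices stored as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

W1 = [[1.0, 2.0], [0.0, -1.0], [3.0, 1.0]]   # 3 x 2
W2 = [[2.0, 0.0, 1.0], [-1.0, 1.0, 0.0]]     # 2 x 3
c1, c2 = 0.5, 2.0                             # slopes of the linear activations
x = [1.0, -2.0]

# Two-layer model with linear activations phi_l(a) = c_l * a.
deep = [c2 * v for v in matvec(W2, [c1 * u for u in matvec(W1, x)])]

# Equivalent single linear layer W' = c2 * c1 * W2 W1.
W_prime = [[c2 * c1 * w for w in row] for row in matmul(W2, W1)]
shallow = matvec(W_prime, x)

print(deep, shallow)
```

Adding more linear layers changes $\mathbf{W}'$ but never the fact that the model is a single linear map.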

In the early days of neural networks, a common choice was to use a sigmoid (logistic) function,
which can be seen as a smooth approximation to the Heaviside function used in a perceptron:

$$\sigma(a) = \frac{1}{1 + e^{-a}} \tag{13.11}$$

However, as shown in Figure 13.2a, the sigmoid function saturates at 1 for large positive inputs,
and at 0 for large negative inputs. Another common choice is the tanh function, which has a similar
shape, but saturates at -1 and +1. See Figure 13.2b.

In the saturated regimes, the gradient of the output wrt the input will be close to zero, so any
gradient signal from higher layers will not be able to propagate back to earlier layers. This is called
the vanishing gradient problem, and it makes it hard to train the model using gradient descent
(see Section 13.4.2 for details). One of the keys to being able to train very deep models is to use
non-saturating activation functions. Several different functions have been proposed. The most
common is the rectified linear unit or ReLU, proposed in [GBB11; KSH12]. This is defined as


$$\text{ReLU}(a) = \max(a, 0) = a \, \mathbb{I}(a > 0) \tag{13.12}$$


The ReLU function simply “turns off” negative inputs, and passes positive inputs unchanged: see 
Figure 13.2b for a plot, and Section 13.4.3 for more details. 
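
The saturation effect can be seen directly from the derivatives of these functions. The following sketch (function names are illustrative) evaluates the derivative of the sigmoid, tanh and ReLU at a few inputs; the closed-form derivatives used are the standard ones, $\sigma'(a) = \sigma(a)(1 - \sigma(a))$ and $\tanh'(a) = 1 - \tanh^2(a)$.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def d_sigmoid(a):
    s = sigmoid(a)
    return s * (1.0 - s)

def d_tanh(a):
    return 1.0 - math.tanh(a) ** 2

def relu(a):
    return max(a, 0.0)

def d_relu(a):
    # Derivative for a != 0; we use 0 at a = 0 by convention.
    return 1.0 if a > 0 else 0.0

for a in (-10.0, 0.0, 10.0):
    print(a, d_sigmoid(a), d_tanh(a), d_relu(a))
```

At $a = 10$ the sigmoid and tanh derivatives are close to zero, whereas the ReLU derivative is still 1, so the gradient passes through unchanged for positive inputs.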

When neural networks are used to represent functions defined on a continuous input space,
such as points in time, $f(t)$, or in 3d space, $f(x, y, z)$, they are often called neural implicit
representations or coordinate-based representations of the underlying signal. In such cases,
it is often important to capture high frequencies to represent the signal faithfully. Unfortunately,
MLPs have an intrinsic bias to low frequency functions [Tan+20; RML22]. One simple solution is to
use a sine function, $\sin(a)$, as the nonlinearity, instead of ReLU, as explained in [Sit+20].²


2. For some simple illustrations of the surprising power of the sine activation function for learning functions 


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




Figure 13.3: An MLP with 2 hidden layers applied to a set of 2d points from 2 classes, shown in the top left
corner. The visualizations associated with each hidden unit show the decision boundary at that part of the
network. The final output is shown on the right. The input is $\mathbf{x} \in \mathbb{R}^2$, the first layer activations are $\mathbf{z}_1 \in \mathbb{R}^4$, the
second layer activations are $\mathbf{z}_2 \in \mathbb{R}^2$, and the final logit is $a_3 \in \mathbb{R}$, which is converted to a probability using the
sigmoid function. This is a screenshot from the interactive demo at http://playground.tensorflow.org.


13.2.4 Example models 


MLPs can be used to perform classification and regression for many kinds of data. We give some 
examples below. 


13.2.4.1 MLP for classifying 2d data into 2 categories 


Figure 13.3 gives an illustration of an MLP with two hidden layers applied to a 2d input vector, 
corresponding to points in the plane, coming from two concentric circles. This model has the following 
form: 


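$$p(y|\mathbf{x}; \boldsymbol{\theta}) = \text{Ber}(y|\sigma(a_3)) \tag{13.13}$$
$$a_3 = \mathbf{w}_3^\top \mathbf{z}_2 + b_3 \tag{13.14}$$
$$\mathbf{z}_2 = \varphi(\mathbf{W}_2 \mathbf{z}_1 + \mathbf{b}_2) \tag{13.15}$$
$$\mathbf{z}_1 = \varphi(\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1) \tag{13.16}$$


Here $a_3$ is the final logit score, which is converted to a probability via the sigmoid (logistic) function.
The value $a_3$ is computed by taking a linear combination of the 2 hidden units in layer 2, using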

on low dimensional input spaces, see https://nipunbatra.github.io/blog/posts/siren-paper.html.
Model: "sequential"

Layer (type)         Output Shape    Param #
flatten (Flatten)    (None, 784)     0
dense (Dense)        (None, 128)     100480
dense_1 (Dense)      (None, 128)     16512
dense_2 (Dense)      (None, 10)      1290

Total params: 118,282
Trainable params: 118,282
Non-trainable params: 0


Table 13.2: Structure of the MLP used for MNIST classification. Note that 100,480 = (784 + 1) × 128, and
16,512 = (128 + 1) × 128. Generated by mlp_mnist_tf.ipynb.


$a_3 = \mathbf{w}_3^\top \mathbf{z}_2 + b_3$. In turn, layer 2 is computed by taking a nonlinear combination of the 4 hidden units
in layer 1, using $\mathbf{z}_2 = \varphi(\mathbf{W}_2 \mathbf{z}_1 + \mathbf{b}_2)$. Finally, layer 1 is computed by taking a nonlinear combination of
the 2 input units, using $\mathbf{z}_1 = \varphi(\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1)$. By adjusting the parameters, $\boldsymbol{\theta} = (\mathbf{W}_1, \mathbf{b}_1, \mathbf{W}_2, \mathbf{b}_2, \mathbf{w}_3, b_3)$,
to minimize the negative log likelihood, we can fit the training data very well, despite the highly
nonlinear nature of the decision boundary. (You can find an interactive version of this figure at
http://playground.tensorflow.org.)
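
The following sketch spells out this model and its training objective in plain Python: a 2 → 4 → 2 → 1 network whose logit is passed through a sigmoid, scored by the average Bernoulli negative log likelihood. The tanh activation, the toy data points and the all-zero parameters are assumptions made for illustration; they are not taken from the interactive demo.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def affine(W, b, z):
    """Compute W z + b for a matrix W stored as a list of rows."""
    return [sum(w * zj for w, zj in zip(row, z)) + bk for row, bk in zip(W, b)]

def logit(theta, x):
    """Final logit a3 of the 2 -> 4 -> 2 -> 1 model."""
    W1, b1, W2, b2, w3, b3 = theta
    z1 = [math.tanh(a) for a in affine(W1, b1, x)]   # 4 hidden units
    z2 = [math.tanh(a) for a in affine(W2, b2, z1)]  # 2 hidden units
    return sum(w * z for w, z in zip(w3, z2)) + b3

def nll(theta, data):
    """Average Bernoulli negative log likelihood over (x, y) pairs, y in {0, 1}."""
    total = 0.0
    for x, y in data:
        p = sigmoid(logit(theta, x))
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(data)

# Toy data and all-zero parameters (illustrative only).
data = [([0.1, 0.0], 1), ([2.0, 0.0], 0), ([0.0, -2.0], 0), ([-0.1, 0.1], 1)]
theta0 = ([[0.0, 0.0] for _ in range(4)], [0.0] * 4,
          [[0.0] * 4 for _ in range(2)], [0.0] * 2,
          [0.0, 0.0], 0.0)
print(nll(theta0, data))  # every prediction is 0.5, so the NLL equals log 2
```

Training then amounts to adjusting `theta` to push this quantity below its uninformed value of $\log 2$.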


13.2.4.2 MLP for image classification 


To apply an MLP to image classification, we need to “flatten” the 2d input into a 1d vector. We can
then use a feedforward architecture similar to the one described in Section 13.2.4.1. For example,
consider building an MLP to classify MNIST digits (Section 3.5.2). These are 28 × 28 = 784-dimensional.
If we use 2 hidden layers with 128 units each, followed by a final 10-way softmax layer,
we get the model shown in Table 13.2.
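
The parameter counts in Table 13.2 follow from the layer sizes: a dense layer with $n_{\text{in}}$ inputs and $n_{\text{out}}$ outputs has $(n_{\text{in}} + 1) \times n_{\text{out}}$ parameters, where the extra 1 accounts for the bias. A small sketch (the helper name is hypothetical):

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return (n_in + 1) * n_out

sizes = [784, 128, 128, 10]  # flattened input, two hidden layers, 10 classes
counts = [dense_params(a, b) for a, b in zip(sizes[:-1], sizes[1:])]
total = sum(counts)
print(counts, total)
```

Most of the parameters sit in the first layer, because it is connected to all 784 pixels.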

We show some predictions from this model in Figure 13.4. We train it for just two “epochs” (passes 
over the dataset), but already the model is doing quite well, with a test set accuracy of 97.1%. 
Furthermore, the errors seem sensible, e.g., 9 is mistaken as a 3. Training for more epochs can further 
improve test accuracy. 

In Chapter 14 we discuss a different kind of model, called a convolutional neural network, which 
is better suited to images. This gets even better performance and uses fewer parameters, by 
exploiting prior knowledge about the spatial structure of images. By contrast, with an MLP, we 
can randomly shuffle (permute) the pixels without affecting the output (assuming we use the same 
random permutation for all inputs). 
Figure 13.4: Results of applying an MLP (with 2 hidden layers with 128 units and 1 output layer with 10
units) to some MNIST images (cherry picked to include some errors). Red is incorrect, blue is correct. (a)
After 1 epoch of training. (b) After 2 epochs. Generated by mlp_mnist_tf.ipynb.


13.2.4.3 MLP for text classification 


To apply MLPs to text classification, we need to convert the variable-length sequence of words
$\mathbf{v}_1, \ldots, \mathbf{v}_T$ (where each $\mathbf{v}_t$ is a one-hot vector of length $V$, where $V$ is the vocabulary size) into a
fixed dimensional vector $\mathbf{x}$. The easiest way to do this is as follows. First we treat the input as an
unordered bag of words (Section 1.5.4.1), $\{\mathbf{v}_t\}$. The first layer of the model is an $E \times V$ embedding
matrix $\mathbf{W}_1$, which converts each sparse $V$-dimensional vector to a dense $E$-dimensional embedding,
$\mathbf{e}_t = \mathbf{W}_1 \mathbf{v}_t$ (see Section 20.5 for more details on word embeddings). Next we convert this set of $T$
$E$-dimensional embeddings into a fixed-sized vector using global average pooling, $\overline{\mathbf{e}} = \frac{1}{T} \sum_{t=1}^{T} \mathbf{e}_t$.
This can then be passed as input to an MLP. For example, if we use a single hidden layer, and a
logistic output (for binary classification), we get


$$p(y|\mathbf{x}; \boldsymbol{\theta}) = \text{Ber}(y|\sigma(\mathbf{w}_3^\top \mathbf{h} + b_3)) \tag{13.17}$$
$$\mathbf{h} = \varphi(\mathbf{W}_2 \overline{\mathbf{e}} + \mathbf{b}_2) \tag{13.18}$$
$$\overline{\mathbf{e}} = \frac{1}{T} \sum_{t=1}^{T} \mathbf{e}_t \tag{13.19}$$
$$\mathbf{e}_t = \mathbf{W}_1 \mathbf{v}_t \tag{13.20}$$


If we use a vocabulary size of $V = 10{,}000$, an embedding size of $E = 16$, and a hidden layer of size
16, we get the model shown in Table 13.3. If we apply this to the IMDB movie review sentiment
classification dataset discussed in Section 1.5.2.1, we get 86% accuracy on the validation set.
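
The embedding and pooling steps can be sketched in a few lines. Multiplying $\mathbf{W}_1$ by a one-hot vector simply selects a column, so the sketch stores the embeddings as a list indexed by token id; the tiny vocabulary, the embedding values and the function name are assumptions for illustration. The last lines recompute the parameter count for the sizes quoted above.

```python
def average_embedding(embeddings, token_ids):
    """Global average pooling of the embeddings of the given tokens.

    embeddings[v] plays the role of column v of W1, i.e. W1 times one_hot(v).
    """
    T = len(token_ids)
    E = len(embeddings[0])
    return [sum(embeddings[t][d] for t in token_ids) / T for d in range(E)]

# Toy vocabulary of 3 words with 2-dimensional embeddings (illustrative values).
embeddings = [[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]]
e_bar = average_embedding(embeddings, [0, 2, 2, 1])
e_bar_shuffled = average_embedding(embeddings, [2, 1, 0, 2])
print(e_bar, e_bar_shuffled)

# Parameter count for V = 10,000, E = 16 and a hidden layer of size 16.
V, E, H = 10_000, 16, 16
n_params = V * E + (E * H + H) + (H + 1)
print(n_params)
```

The pooled vector is the same for both orderings of the tokens, which is what treating the input as a bag of words means in practice.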

We see from Table 13.3 that the model has a lot of parameters, which can result in overfitting, 
since the IMDB training set only has 25k examples. However, we also see that most of the parameters 
are in the embedding matrix, so instead of learning these in a supervised way, we can perform 
unsupervised pre-training of word embedding models, as we discuss in Section 20.5. If the embedding 
matrix $\mathbf{W}_1$ is fixed, we just have to fine-tune the parameters in layers 2 and 3 for this specific labeled
task, which requires much less data. (See also Chapter 19, where we discuss general techniques for 
training with limited labeled data.) 
Model: "sequential"

Layer (type)                    Output Shape        Param #
embedding (Embedding)           (None, None, 16)    160000
global_average_pooling1d (Gl    (None, 16)          0
dense (Dense)                   (None, 16)          272
dense_1 (Dense)                 (None, 1)           17

Total params: 160,289
Trainable params: 160,289
Non-trainable params: 0


Table 13.3: Structure of the MLP used for IMDB review classification. We use a vocabulary size of $V = 10{,}000$,
an embedding size of $E = 16$, and a hidden layer of size 16. The embedding matrix $\mathbf{W}_1$ has size 10,000 × 16,
the hidden layer (labeled “dense”) has a weight matrix $\mathbf{W}_2$ of size 16 × 16 and bias $\mathbf{b}_2$ of size 16 (note that
16 × 16 + 16 = 272), and the final layer (labeled “dense_1”) has a weight vector $\mathbf{w}_3$ of size 16 and a bias $b_3$
of size 1. The global average pooling layer has no free parameters. Generated by mlp_imdb_tf.ipynb.
Figure 13.5: Illustration of an MLP with a shared “backbone” and two output “heads”, one for predicting
the mean and one for predicting the variance. From https://brendanhasz.github.io/2019/07/23/bayesian-density-net.html.
Used with kind permission of Brendan Hasz.


13.2.4.4 MLP for heteroskedastic regression 


We can also use MLPs for regression. Figure 13.5 shows how we can make a model for heteroskedastic 
nonlinear regression. (The term “heteroskedastic” just means that the predicted output variance 
is input-dependent, as discussed in Section 2.6.3.) This function has two outputs which compute
$f_\mu(\mathbf{x}) = \mathbb{E}[y|\mathbf{x}, \boldsymbol{\theta}]$ and $f_\sigma(\mathbf{x}) = \sqrt{\mathbb{V}[y|\mathbf{x}, \boldsymbol{\theta}]}$. We can share most of the layers (and hence parameters)
between these two functions by using a common “backbone” and two output “heads”, as shown in
Figure 13.5. For the $\mu$ head, we use a linear activation, $\varphi(a) = a$. For the $\sigma$ head, we use a softplus
activation, $\varphi(a) = \sigma_+(a) = \log(1 + e^a)$. If we use linear heads and a nonlinear backbone, the overall
model is given by


$$p(y|\mathbf{x}, \boldsymbol{\theta}) = \mathcal{N}\left(y \mid \mathbf{w}_\mu^\top f(\mathbf{x}; \mathbf{w}_{\text{shared}}), \sigma_+(\mathbf{w}_\sigma^\top f(\mathbf{x}; \mathbf{w}_{\text{shared}}))\right) \tag{13.21}$$


Figure 13.6: Illustration of predictions from an MLP fit using MLE to a 1d regression dataset with growing
noise. (a) Output variance is input-dependent, as in Figure 13.5. (b) Mean is computed using same model as
in (a), but output variance is treated as a fixed parameter $\sigma^2$, which is estimated by MLE after training, as in
Section 11.2.3.6. Generated by mlp_1d_regression_hetero_tfp.ipynb.


Figure 13.6 shows the advantage of this kind of model on a dataset where the mean grows linearly 
over time, with seasonal oscillations, and the variance increases quadratically. (This is a simple 
example of a stochastic volatility model; it can be used to model financial data, as well as the 
global temperature of the earth, which (due to climate change) is increasing in mean and in variance.) 
We see that a regression model where the output variance $\sigma^2$ is treated as a fixed (input-independent)
parameter will sometimes be underconfident, since it needs to adjust to the overall noise level, and
cannot adapt to the noise level at each point in input space.
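
A minimal sketch of the two heads and the corresponding per-example loss is given below. It assumes the backbone features are already computed, treats the softplus output as the standard deviation of the Gaussian, and uses made-up weights; the function names are illustrative and are not taken from the notebook.

```python
import math

def softplus(a):
    """log(1 + e^a), written to avoid overflow for large |a|."""
    return max(a, 0.0) + math.log1p(math.exp(-abs(a)))

def dot(w, f):
    return sum(wi * fi for wi, fi in zip(w, f))

def predict(features, w_mu, w_sigma):
    """Linear mean head and softplus scale head on shared backbone features."""
    mu = dot(w_mu, features)
    sigma = softplus(dot(w_sigma, features))  # always positive
    return mu, sigma

def gaussian_nll(y, mu, sigma):
    """Negative log density of N(y | mu, sigma^2), with sigma a standard deviation."""
    return 0.5 * math.log(2.0 * math.pi * sigma ** 2) + 0.5 * ((y - mu) / sigma) ** 2

features = [1.0, 2.0]  # stand-in for the backbone output f(x; w_shared)
mu, sigma = predict(features, w_mu=[0.5, 0.25], w_sigma=[-1.0, 1.5])
print(mu, sigma, gaussian_nll(1.5, mu, sigma))
```

For an observation far from the predicted mean, a larger predicted scale gives a lower loss, which is how the model learns to widen its predictive interval in noisy regions of input space.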


13.2.5 The importance of depth 


One can show that an MLP with one hidden layer is a universal function approximator, meaning 
it can model any suitably smooth function, given enough hidden units, to any desired level of accuracy 
[HSW89; Cyb89; Hor91]. Intuitively, the reason for this is that each hidden unit can specify a half 
plane, and a sufficiently large combination of these can “carve up” any region of space, to which we 
can associate any response (this is easiest to see when using piecewise linear activation functions, as 
shown in Figure 13.7). 

However, various arguments, both experimental and theoretical (e.g., [Has87; Mon+14; Rag+17; 
Pog+17]), have shown that deep networks work better than shallow ones. The reason is that later 
layers can leverage the features that are learned by earlier layers; that is, the function is defined 
in a compositional or hierarchical way. For example, suppose we want to classify DNA strings, 
and the positive class is associated with the string *AA??CGCG??AA*, where ? is a wildcard denoting 
any single character, and * is a wildcard denoting any sequence of characters (possibly of length 0). 
Although we could fit this with a single hidden layer model, intuitively it will be easier to learn if the 
model first learns to detect the AA and CG “motifs” using the hidden units in layer 1, and then uses 
Figure 13.7: A decomposition of $\mathbb{R}^2$ into a finite set of linear decision regions produced by an MLP with
ReLU activations with (a) one hidden layer of 25 hidden units and (b) two hidden layers. From Figure 1 of
[HAB19]. Used with kind permission of Maksym Andriuschenko.


these features to define a simple linear classifier in layer 2, analogously to how we solved the XOR 
problem in Section 13.2.1. 


13.2.6 The “deep learning revolution” 


Although the ideas behind DNNs date back several decades, it was not until the 2010s that they 
started to become very widely used. The first area to adopt these methods was the field of automatic 
speech recognition (ASR), based on breakthrough results in [Dah+11]. This approach rapidly became 
the standard paradigm, and was widely adopted in academia and industry [Hin+12]. 

However, the moment that got the most attention was when [KSH12] showed that deep CNNs 
could significantly improve performance on the challenging ImageNet image classification benchmark, 
reducing the error rate from 26% to 16% in a single year (see Figure 1.14b); this was a huge jump 
compared to the previous rate of progress of about 2% reduction per year. 

The “explosion” in the usage of DNNs has several contributing factors. One is the availability 
of cheap GPUs (graphics processing units); these were originally developed to speed up image 
rendering for video games, but they can also massively reduce the time it takes to fit large CNNs, 
which involve similar kinds of matrix-vector computations. Another is the growth in large labeled 
datasets, which enables us to fit complex function approximators with many parameters without 
overfitting. (For example, ImageNet has 1.3M labeled images, and is used to fit models that have 
millions of parameters.) Indeed, if deep learning systems are viewed as “rockets”, then large datasets
have been called the fuel.³

Motivated by the outstanding empirical success of DNNs, various companies started to become
interested in this technology. This has led to the development of high quality open-source software
libraries, such as TensorFlow (made by Google), PyTorch (made by Facebook), and MXNet (made
by Amazon). These libraries support automatic differentiation (see Section 13.3) and scalable 
gradient-based optimization (see Section 8.4) of complex differentiable functions. We will use some 


3. This popular analogy is due to Andrew Ng, who mentioned it in a keynote talk at the GPU Technology Conference 
(GTC) in 2015. His slides are available at https://bit.ly/38RTxzH.
Figure 13.8: Illustration of two neurons connected together in a “circuit”. The output axon of the left neuron
makes a synaptic connection with the dendrites of the cell on the right. Electrical charges, in the form of ion
flows, allow the cells to communicate. From https://en.wikipedia.org/wiki/Neuron. Used with kind
permission of Wikipedia author BruceBlaus.


of these libraries in various places throughout the book to implement a variety of models, not just 
DNNs.⁴
More details on the history of the “deep learning revolution” can be found in e.g., [Sej18; Met21]. 


13.2.7 Connections with biology 


In this section, we discuss the connections between the kinds of neural networks we have discussed 
above, known as artificial neural networks or ANNs, and real neural networks. The details on 
how real biological brains work are quite complex (see e.g., [Kan+12]), but we can give a simple 
“cartoon”. 

We start by considering a model of a single neuron. To a first approximation, we can say that 
whether neuron $k$ fires, denoted by $h_k \in \{0,1\}$, depends on the activity of its inputs, denoted by
$\mathbf{x} \in \mathbb{R}^D$, as well as the strength of the incoming connections, which we denote by $\mathbf{w}_k \in \mathbb{R}^D$. We
can compute a weighted sum of the inputs using $a_k = \mathbf{w}_k^\top \mathbf{x}$. These weights can be viewed as “wires”
connecting the inputs $x_d$ to neuron $h_k$; these are analogous to dendrites in a real neuron (see
Figure 13.8). This weighted sum is then compared to a threshold, $b_k$, and if the activation exceeds
the threshold, the neuron fires; this is analogous to the neuron emitting an electrical output or
action potential. Thus we can model the behavior of the neuron using $h_k(\mathbf{x}) = H(\mathbf{w}_k^\top \mathbf{x} - b_k)$,
where $H(a) = \mathbb{I}(a > 0)$ is the Heaviside function. This is called the McCulloch-Pitts model of
the neuron, and was proposed in 1943 [MP43].

We can combine multiple such neurons together to make an ANN. The result has sometimes been 


4. Note, however, that some have argued (see e.g., [BI19]) that current libraries are too inflexible, and put too much 
emphasis on methods based on dense matrix-vector multiplication, as opposed to more general algorithmic primitives. 
Figure 13.9: Plot of neural network sizes over time (number of neurons, on a logarithmic scale, against year).
Models 1, 2, 3 and 4 correspond to the perceptron
[Ros58], the adaptive linear unit [WH60], the neocognitron [Fuk80], and the first MLP trained by backprop
[RHW86]. Approximate number of neurons for some living organisms are shown on the right scale (the sponge
has 0 neurons), based on https://en.wikipedia.org/wiki/List_of_animals_by_number_of_neurons.
From Figure 1.11 of [GBC16]. Used with kind permission of Ian Goodfellow.


viewed as a model of the brain. However, ANNs differ from biological brains in many ways, including
the following: 


Most ANNs use backpropagation to modify the strength of their connections (see Section 13.3). 
However, real brains do not use backprop, since there is no way to send information backwards 
along an axon [Ben+15b; BS16; KH19]. Instead, they use local update rules for adjusting synaptic 
strengths. 


Most ANNs are strictly feedforward, but real brains have many feedback connections. It is believed 
that this feedback acts like a prior, which can be combined with bottom up likelihoods from the 
sensory system to compute a posterior over hidden states of the world, which can then be used for 
optimal decision making (see e.g., [Doy+07]). 


Most ANNs use simplified neurons consisting of a weighted sum passed through a nonlinearity, 
but real biological neurons have complex dendritic tree structures (see Figure 13.8), with complex 
spatio-temporal dynamics. 


Most ANNs are smaller in size and number of connections than biological brains (see Figure 13.9). 
Of course, ANNs are getting larger every week, fueled by various new hardware accelerators, 
such as GPUs and TPUs (tensor processing units), etc. However, even if ANNs match 
biological brains in terms of number of units, the comparison is misleading since the processing 
capability of a biological neuron is much higher than an artificial neuron (see point above). 


Most ANNs are designed to model a single function, such as mapping an image to a label, or a 
sequence of words to another sequence of words. By contrast, biological brains are very complex 
systems, composed of multiple specialized interacting modules, which implement different kinds 
of functions or behaviors such as perception, control, memory, language, etc. (see e.g., [Sha88;
Kan+12]).
Of course, there are efforts to make realistic models of biological brains (e.g., the Blue Brain
Project [Mar06; Yon19]). However, an interesting question is whether studying the brain at this
level of detail is useful for “solving AI”. It is commonly believed that the low level details of biological
brains do not matter if our goal is to build “intelligent machines”, just as aeroplanes do not flap their
wings. However, presumably “AIs” will follow similar “laws of intelligence” to intelligent biological
agents, just as planes and birds follow the same laws of aerodynamics.

Unfortunately, we do not yet know what the “laws of intelligence” are, or indeed if there even are 
such laws. In this book we make the assumption that any intelligent agent should follow the basic 
principles of information processing and Bayesian decision theory, which is known to be the optimal 
way to make decisions under uncertainty (see Section 5.1). 

In practice, the optimal Bayesian approach is often computationally intractable. In the natural 
world, biological agents have evolved various algorithmic “shortcuts” to the optimal solution; this 
can explain many of the heuristics that people use in everyday reasoning [KST82; GTA00; Gri20]. 
As the tasks we want our machines to solve become harder, we may be able to gain insights from 
neuroscience and cognitive science for how to solve such tasks in an approximate way (see e.g., 
[MWK16; Has+17; Lak+17; HG21]). However, we should also bear in mind that AI/ML systems are 
increasingly used for safety-critical applications, in which we might want and expect the machine 
to do better than a human. In such cases, we may want more than just heuristic solutions that 
often work; instead we may want provably reliable methods, similar to other engineering fields (see 
Section 1.6.3 for further discussion). 


13.3 Backpropagation 


This section is coauthored with Mathieu Blondel. 


In this section, we describe the famous backpropagation algorithm, which can be used to 
compute the gradient of a loss function applied to the output of the network wrt the parameters 
in each layer. This gradient can then be passed to a gradient-based optimization algorithm, as we 
discuss in Section 13.4. 

The backpropagation algorithm was originally discovered in [BH69], and independently in [Wer74]. 
However, it was [RHW86] that brought the algorithm to the attention of the “mainstream” ML 
community. See the Wikipedia page for more historical details.⁵

We initially assume the computation graph is a simple linear chain of stacked layers, as in an 
MLP. In this case, backprop is equivalent to repeated applications of the chain rule of calculus 
(see Equation (7.261)). However, the method can be generalized to arbitrary directed acyclic 
graphs (DAGs), as we discuss in Section 13.3.4. This general procedure is often called automatic 
differentiation or autodiff. 


13.3.1 Forward vs reverse mode differentiation 


Consider a mapping of the form $\mathbf{o} = f(\mathbf{x})$, where $\mathbf{x} \in \mathbb{R}^n$ and $\mathbf{o} \in \mathbb{R}^m$. We assume that $f$ is defined
as a composition of functions:


$$f = f_4 \circ f_3 \circ f_2 \circ f_1 \tag{13.22}$$


5. https://en.wikipedia.org/wiki/Backpropagation#History
Figure 13.10: A simple linear-chain feedforward model with 4 layers. Here x is the input and o is the output.
From [Blo20].


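where $f_1 : \mathbb{R}^n \to \mathbb{R}^{m_1}$, $f_2 : \mathbb{R}^{m_1} \to \mathbb{R}^{m_2}$, $f_3 : \mathbb{R}^{m_2} \to \mathbb{R}^{m_3}$, and $f_4 : \mathbb{R}^{m_3} \to \mathbb{R}^m$. The intermediate
steps needed to compute $\mathbf{o} = f(\mathbf{x})$ are $\mathbf{x}_2 = f_1(\mathbf{x})$, $\mathbf{x}_3 = f_2(\mathbf{x}_2)$, $\mathbf{x}_4 = f_3(\mathbf{x}_3)$, and $\mathbf{o} = f_4(\mathbf{x}_4)$.
We can compute the Jacobian $\mathbf{J}_f(\mathbf{x}) = \frac{\partial \mathbf{o}}{\partial \mathbf{x}} \in \mathbb{R}^{m \times n}$ using the chain rule:


$$\frac{\partial \mathbf{o}}{\partial \mathbf{x}} = \frac{\partial \mathbf{o}}{\partial \mathbf{x}_4} \frac{\partial \mathbf{x}_4}{\partial \mathbf{x}_3} \frac{\partial \mathbf{x}_3}{\partial \mathbf{x}_2} \frac{\partial \mathbf{x}_2}{\partial \mathbf{x}} = \frac{\partial f_4(\mathbf{x}_4)}{\partial \mathbf{x}_4} \frac{\partial f_3(\mathbf{x}_3)}{\partial \mathbf{x}_3} \frac{\partial f_2(\mathbf{x}_2)}{\partial \mathbf{x}_2} \frac{\partial f_1(\mathbf{x})}{\partial \mathbf{x}} \tag{13.23}$$
$$= \mathbf{J}_{f_4}(\mathbf{x}_4) \mathbf{J}_{f_3}(\mathbf{x}_3) \mathbf{J}_{f_2}(\mathbf{x}_2) \mathbf{J}_{f_1}(\mathbf{x}) \tag{13.24}$$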


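We now discuss how to compute the Jacobian $\mathbf{J}_f(\mathbf{x})$ efficiently. Recall that


$$\mathbf{J}_f(\mathbf{x}) = \frac{\partial f(\mathbf{x})}{\partial \mathbf{x}} = \begin{pmatrix} \frac{\partial f_1}{\partial x_1} & \cdots & \frac{\partial f_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial f_m}{\partial x_1} & \cdots & \frac{\partial f_m}{\partial x_n} \end{pmatrix} = \begin{pmatrix} \nabla f_1(\mathbf{x})^\top \\ \vdots \\ \nabla f_m(\mathbf{x})^\top \end{pmatrix} = \left( \frac{\partial f}{\partial x_1}, \ldots, \frac{\partial f}{\partial x_n} \right) \in \mathbb{R}^{m \times n} \tag{13.25}$$


where $\nabla f_i(\mathbf{x})^\top \in \mathbb{R}^{1 \times n}$ is the $i$'th row (for $i = 1:m$) and $\frac{\partial f}{\partial x_j} \in \mathbb{R}^m$ is the $j$'th column (for $j = 1:n$).


Note that, in our notation, when $m = 1$, the gradient, denoted $\nabla f(\mathbf{x})$, has the same shape as $\mathbf{x}$. It
is therefore a column vector, while $\mathbf{J}_f(\mathbf{x})$ is a row vector. In this case, we therefore technically have
$\nabla f(\mathbf{x}) = \mathbf{J}_f(\mathbf{x})^\top$.

We can extract the $i$'th row from $\mathbf{J}_f(\mathbf{x})$ by using a vector Jacobian product (VJP) of the form
$\mathbf{e}_i^\top \mathbf{J}_f(\mathbf{x})$, where $\mathbf{e}_i \in \mathbb{R}^m$ is the unit basis vector. Similarly, we can extract the $j$'th column from
$\mathbf{J}_f(\mathbf{x})$ by using a Jacobian vector product (JVP) of the form $\mathbf{J}_f(\mathbf{x}) \mathbf{e}_j$, where $\mathbf{e}_j \in \mathbb{R}^n$. This shows
that the computation of $\mathbf{J}_f(\mathbf{x})$ reduces to either $n$ JVPs or $m$ VJPs.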

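If $n < m$, it is more efficient to compute $\mathbf{J}_f(\mathbf{x})$ for each column $j = 1:n$ by using JVPs in a
right-to-left manner. The right multiplication with a column vector $\mathbf{v}$ is


$$\mathbf{J}_f(\mathbf{x}) \mathbf{v} = \underbrace{\mathbf{J}_{f_4}(\mathbf{x}_4)}_{m \times m_3} \underbrace{\mathbf{J}_{f_3}(\mathbf{x}_3)}_{m_3 \times m_2} \underbrace{\mathbf{J}_{f_2}(\mathbf{x}_2)}_{m_2 \times m_1} \underbrace{\mathbf{J}_{f_1}(\mathbf{x}_1)}_{m_1 \times n} \underbrace{\mathbf{v}}_{n \times 1} \tag{13.26}$$


This can be computed using forward mode differentiation; see Algorithm 13.1 for the pseudocode.
Assuming $m = 1$ and $n = m_1 = m_2 = m_3$, the cost of computing $\mathbf{J}_f(\mathbf{x})$ is $O(n^3)$.

If $n > m$ (e.g., if the output is a scalar), it is more efficient to compute $\mathbf{J}_f(\mathbf{x})$ for each row $i = 1:m$
by using VJPs in a left-to-right manner. The left multiplication with a row vector $\mathbf{u}^\top$ is


$$\mathbf{u}^\top \mathbf{J}_f(\mathbf{x}) = \underbrace{\mathbf{u}^\top}_{1 \times m} \underbrace{\mathbf{J}_{f_4}(\mathbf{x}_4)}_{m \times m_3} \underbrace{\mathbf{J}_{f_3}(\mathbf{x}_3)}_{m_3 \times m_2} \underbrace{\mathbf{J}_{f_2}(\mathbf{x}_2)}_{m_2 \times m_1} \underbrace{\mathbf{J}_{f_1}(\mathbf{x}_1)}_{m_1 \times n} \tag{13.27}$$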
Algorithm 13.1: Forward mode differentiation

x_1 := x
v_j := e_j ∈ R^n for j = 1 : n
for k = 1 : K do
    x_{k+1} = f_k(x_k)
    v_j := J_{f_k}(x_k) v_j for j = 1 : n
Return o = x_{K+1}, [J_f(x)]_{:,j} = v_j for j = 1 : n


This can be done using reverse mode differentiation; see Algorithm 13.2 for the pseudocode.
Assuming $m = 1$ and $n = m_1 = m_2 = m_3$, the cost of computing $\mathbf{J}_f(\mathbf{x})$ is $O(n^2)$.


Algorithm 13.2: Reverse mode differentiation

x_1 := x
for k = 1 : K do
    x_{k+1} = f_k(x_k)
u_i := e_i ∈ R^m for i = 1 : m
for k = K : 1 do
    u_i^T := u_i^T J_{f_k}(x_k) for i = 1 : m
Return o = x_{K+1}, [J_f(x)]_{i,:} = u_i^T for i = 1 : m


Both Algorithms 13.1 and 13.2 can be adapted to compute JVPs and VJPs against any collection
of input vectors, by accepting $\{\mathbf{v}_j\}_{j=1,\ldots,n}$ and $\{\mathbf{u}_i\}_{i=1,\ldots,m}$ as respective inputs. Initializing these
vectors to the standard basis is useful specifically for producing the complete Jacobian as output.
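
The two algorithms can be sketched in plain Python for a small chain of layers. Each layer here is assumed to be of the form $f_k(\mathbf{x}) = \tanh(\mathbf{A}_k \mathbf{x})$, whose Jacobian is $\text{diag}(1 - \tanh^2(\mathbf{A}_k \mathbf{x})) \mathbf{A}_k$; the layer form, the matrices and the function names are assumptions made for illustration. Forward mode pushes the $n$ basis vectors through the chain alongside the input, while reverse mode first stores the intermediate $\mathbf{x}_k$ and then pulls the $m$ basis vectors back.

```python
import math

def matvec(A, x):
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]

def vecmat(u, A):
    """Row vector times matrix: u^T A."""
    return [sum(u[i] * A[i][j] for i in range(len(A))) for j in range(len(A[0]))]

def make_layer(A):
    """Return f(x) = tanh(A x) and a function computing its Jacobian at x."""
    def f(x):
        return [math.tanh(a) for a in matvec(A, x)]
    def jac(x):
        return [[(1.0 - math.tanh(a) ** 2) * w for w in row]
                for a, row in zip(matvec(A, x), A)]
    return f, jac

def forward_mode(layers, x):
    n = len(x)
    vs = [[1.0 if i == j else 0.0 for i in range(n)] for j in range(n)]
    for f, jac in layers:
        J = jac(x)
        vs = [matvec(J, v) for v in vs]   # one JVP per basis vector
        x = f(x)
    return x, vs                           # vs[j] is column j of the Jacobian

def reverse_mode(layers, x):
    xs = []
    for f, _ in layers:                    # forward pass, storing inputs
        xs.append(x)
        x = f(x)
    m = len(x)
    us = [[1.0 if i == j else 0.0 for j in range(m)] for i in range(m)]
    for (_, jac), xk in zip(reversed(layers), reversed(xs)):
        J = jac(xk)
        us = [vecmat(u, J) for u in us]    # one VJP per basis vector
    return x, us                           # us[i] is row i of the Jacobian

layers = [make_layer([[0.5, -1.0], [1.5, 0.2], [-0.3, 0.8]]),
          make_layer([[1.0, -2.0, 0.5]])]
o_fwd, cols = forward_mode(layers, [0.3, -0.7])
o_rev, rows = reverse_mode(layers, [0.3, -0.7])
print(rows)
```

In this example $n = 2$ and $m = 1$, so reverse mode needs a single backward sweep whereas forward mode propagates two vectors; both produce the same Jacobian entries.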


13.3.2 Reverse mode differentiation for multilayer perceptrons 


In the previous section, we considered a simple linear-chain feedforward model where each layer does
not have any learnable parameters. In this section, each layer can now have (optional) parameters
$\boldsymbol{\theta}_1, \ldots, \boldsymbol{\theta}_4$. See Figure 13.10 for an illustration. We focus on the case where the mapping has the
form $\mathcal{L} : \mathbb{R}^n \to \mathbb{R}$, so the output is a scalar. For example, consider $\ell_2$ loss for an MLP with one hidden
layer:


$$\mathcal{L}((\mathbf{x}, \mathbf{y}), \boldsymbol{\theta}) = \frac{1}{2} \|\mathbf{y} - \mathbf{W}_2 \varphi(\mathbf{W}_1 \mathbf{x})\|_2^2 \tag{13.28}$$


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 




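We can represent this as the following feedforward model:


$$\mathcal{L} = f_4 \circ f_3 \circ f_2 \circ f_1 \tag{13.29}$$
$$\mathbf{x}_2 = f_1(\mathbf{x}, \boldsymbol{\theta}_1) = \mathbf{W}_1 \mathbf{x} \tag{13.30}$$
$$\mathbf{x}_3 = f_2(\mathbf{x}_2, \emptyset) = \varphi(\mathbf{x}_2) \tag{13.31}$$
$$\mathbf{x}_4 = f_3(\mathbf{x}_3, \boldsymbol{\theta}_3) = \mathbf{W}_2 \mathbf{x}_3 \tag{13.32}$$
$$\mathcal{L} = f_4(\mathbf{x}_4, \mathbf{y}) = \frac{1}{2} \|\mathbf{x}_4 - \mathbf{y}\|^2 \tag{13.33}$$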


We use the notation $f_k(\mathbf{x}_k, \boldsymbol{\theta}_k)$ to denote the function at layer $k$, where $\mathbf{x}_k$ is the previous output
and $\boldsymbol{\theta}_k$ are the optional parameters for this layer.
In this example, the final layer returns a scalar, since it corresponds to a loss function $\mathcal{L} \in \mathbb{R}$.
Therefore it is more efficient to use reverse mode differentiation to compute the gradient vectors.
We first discuss how to compute the gradient of the scalar output wrt the parameters in each layer.
We can easily compute the gradient wrt the predictions in the final layer, $\frac{\partial \mathcal{L}}{\partial \mathbf{x}_4}$. For the gradient wrt
the parameters in the earlier layers, we can use the chain rule to get


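\begin{align}
\frac{\partial \mathcal{L}}{\partial \boldsymbol{\theta}_3} &= \frac{\partial \mathcal{L}}{\partial \mathbf{x}_4} \frac{\partial \mathbf{x}_4}{\partial \boldsymbol{\theta}_3} \tag{13.34} \\
\frac{\partial \mathcal{L}}{\partial \boldsymbol{\theta}_2} &= \frac{\partial \mathcal{L}}{\partial \mathbf{x}_4} \frac{\partial \mathbf{x}_4}{\partial \mathbf{x}_3} \frac{\partial \mathbf{x}_3}{\partial \boldsymbol{\theta}_2} \tag{13.35} \\
\frac{\partial \mathcal{L}}{\partial \boldsymbol{\theta}_1} &= \frac{\partial \mathcal{L}}{\partial \mathbf{x}_4} \frac{\partial \mathbf{x}_4}{\partial \mathbf{x}_3} \frac{\partial \mathbf{x}_3}{\partial \mathbf{x}_2} \frac{\partial \mathbf{x}_2}{\partial \boldsymbol{\theta}_1} \tag{13.36}
\end{align}

where each $\frac{\partial \mathcal{L}}{\partial \boldsymbol{\theta}_k} = (\nabla_{\boldsymbol{\theta}_k} \mathcal{L})^\top$ is a $d_k$-dimensional gradient row vector, where $d_k$ is the number of
parameters in layer $k$. We see that these can be computed recursively, by multiplying the gradient
row vector at layer $k$ by the Jacobian $\frac{\partial \mathbf{x}_k}{\partial \mathbf{x}_{k-1}}$, which is an $n_k \times n_{k-1}$ matrix, where $n_k$ is the number
of hidden units in layer $k$. See Algorithm 13.3 for the pseudocode.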

This algorithm computes the gradient of the loss wrt the parameters at each layer. It also computes 
the gradient of the loss wrt the input, $\nabla_{\mathbf{x}} \mathcal{L} \in \mathbb{R}^n$, where $n$ is the dimensionality of the input. This
latter quantity is not needed for parameter learning, but can be useful for generating inputs to a 
model (see Section 14.6 for some applications). 

All that remains is to specify how to compute the vector Jacobian product (VJP) of all supported 
layers. The details of this depend on the form of the function at each layer. We discuss some examples 
below. 


13.3.3 Vector-Jacobian product for common layers 


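Recall that the Jacobian for a layer of the form $f : \mathbb{R}^n \rightarrow \mathbb{R}^m$ is defined by

$$\mathbf{J}_f(\mathbf{x}) = \frac{\partial f(\mathbf{x})}{\partial \mathbf{x}^\top} = \begin{pmatrix} \frac{\partial f_1}{\partial x_1} & \cdots & \frac{\partial f_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial f_m}{\partial x_1} & \cdots & \frac{\partial f_m}{\partial x_n} \end{pmatrix} = \begin{pmatrix} \nabla f_1(\mathbf{x})^\top \\ \vdots \\ \nabla f_m(\mathbf{x})^\top \end{pmatrix} = \left( \frac{\partial f}{\partial x_1}, \cdots, \frac{\partial f}{\partial x_n} \right) \in \mathbb{R}^{m \times n} \tag{13.37}$$

where $\nabla f_i(\mathbf{x})^\top \in \mathbb{R}^n$ is the $i$'th row (for $i = 1:m$) and $\frac{\partial f}{\partial x_j} \in \mathbb{R}^m$ is the $j$'th column (for $j = 1:n$).

In this section, we describe how to compute the VJP $\mathbf{u}^\top \mathbf{J}_f(\mathbf{x})$ for common layers.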
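Algorithm 13.3: Backpropagation for an MLP with K layers
  // Forward pass
  $\mathbf{x}_1 := \mathbf{x}$
  for $k = 1:K$ do
    $\mathbf{x}_{k+1} = f_k(\mathbf{x}_k, \boldsymbol{\theta}_k)$
  // Backward pass
  $\mathbf{u}_{K+1} := 1$
  for $k = K:1$ do
    $\mathbf{g}_k := \mathbf{u}_{k+1}^\top \frac{\partial f_k(\mathbf{x}_k, \boldsymbol{\theta}_k)}{\partial \boldsymbol{\theta}_k}$
    $\mathbf{u}_k^\top := \mathbf{u}_{k+1}^\top \frac{\partial f_k(\mathbf{x}_k, \boldsymbol{\theta}_k)}{\partial \mathbf{x}_k}$
  // Output
  Return $\mathcal{L} = \mathbf{x}_{K+1}$, $\nabla_{\mathbf{x}} \mathcal{L} = \mathbf{u}_1$, $\{\nabla_{\boldsymbol{\theta}_k} \mathcal{L} = \mathbf{g}_k : k = 1:K\}$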


13.3.3.1 Cross entropy layer 


Consider a cross-entropy loss layer taking logits x and target labels y as input, and returning a 
scalar: 


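$$z = f(\mathbf{x}) = \text{CrossEntropyWithLogits}(\mathbf{y}, \mathbf{x}) = -\sum_c y_c \log(\text{softmax}(\mathbf{x})_c) = -\sum_c y_c \log p_c \tag{13.38}$$

where $p_c = \text{softmax}(\mathbf{x})_c = \frac{e^{x_c}}{\sum_{c'=1}^{C} e^{x_{c'}}}$ are the predicted class probabilities, and $\mathbf{y}$ is the true distribution
over labels (often a one-hot vector). The Jacobian wrt the input is

$$\mathbf{J} = \frac{\partial z}{\partial \mathbf{x}} = (\mathbf{p} - \mathbf{y})^\top \in \mathbb{R}^{1 \times C} \tag{13.39}$$

To see this, assume the target label is class $c$. We have

$$z = f(\mathbf{x}) = -\log(p_c) = -\log\left( \frac{e^{x_c}}{\sum_j e^{x_j}} \right) = \log\left( \sum_j e^{x_j} \right) - x_c \tag{13.40}$$

Hence

$$\frac{\partial z}{\partial x_i} = \frac{\partial}{\partial x_i} \log \sum_j e^{x_j} - \frac{\partial}{\partial x_i} x_c = \frac{e^{x_i}}{\sum_j e^{x_j}} - \frac{\partial}{\partial x_i} x_c = p_i - \mathbb{I}(i = c) \tag{13.41}$$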


If we define $\mathbf{y} = [\mathbb{I}(i = c)]$, we recover Equation (13.39). Note that the Jacobian of this layer is a row
vector, since the output is a scalar. 
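The result in Equation (13.39) can be checked numerically. The sketch below is a minimal illustration, assuming a one-hot target and arbitrary example logits; the function names are hypothetical. It computes the VJP as $u(\mathbf{p} - \mathbf{y})$ and lets us compare it with central finite differences.

```python
import math

def softmax(x):
    """Numerically stable softmax of a list of logits."""
    m = max(x)
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]

def cross_entropy_with_logits(y, x):
    """z = -sum_c y_c log softmax(x)_c, as in Equation (13.38)."""
    p = softmax(x)
    return -sum(yc * math.log(pc) for yc, pc in zip(y, p))

def cross_entropy_vjp(u, y, x):
    """VJP u * J with J = (p - y)^T; u is a scalar because z is a scalar."""
    p = softmax(x)
    return [u * (pc - yc) for pc, yc in zip(p, y)]

logits = [1.0, -0.5, 2.0]      # example values, chosen arbitrarily
target = [0.0, 0.0, 1.0]       # one-hot label for class 2
grad = cross_entropy_vjp(1.0, target, logits)
print(grad)
```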
13.3.3.2 Elementwise nonlinearity


Consider a layer that applies an elementwise nonlinearity, $\mathbf{z} = f(\mathbf{x}) = \varphi(\mathbf{x})$, so $z_i = \varphi(x_i)$. The $(i, j)$
element of the Jacobian is given by

$$\frac{\partial z_i}{\partial x_j} = \begin{cases} \varphi'(x_i) & \text{if } i = j \\ 0 & \text{otherwise} \end{cases} \tag{13.42}$$

where $\varphi'(a) = \frac{d}{da} \varphi(a)$. In other words, the Jacobian wrt the input is

$$\mathbf{J} = \frac{\partial f}{\partial \mathbf{x}} = \text{diag}(\varphi'(\mathbf{x})) \tag{13.43}$$

For an arbitrary vector $\mathbf{u}$, we can compute $\mathbf{u}^\top \mathbf{J}$ by elementwise multiplication of the diagonal elements
of $\mathbf{J}$ with $\mathbf{u}$. For example, if

$$\varphi(a) = \text{ReLU}(a) = \max(a, 0) \tag{13.44}$$

we have

$$\varphi'(a) = \begin{cases} 0 & a < 0 \\ 1 & a > 0 \end{cases} \tag{13.45}$$

The subderivative (Section 8.1.4.1) at $a = 0$ is any value in $[0, 1]$. It is often taken to be 0. Hence

$$\text{ReLU}'(a) = H(a) \tag{13.46}$$


where H is the Heaviside step function. 


13.3.3.3 Linear layer 


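Now consider a linear layer, $\mathbf{z} = f(\mathbf{x}, \mathbf{W}) = \mathbf{W}\mathbf{x}$, where $\mathbf{W} \in \mathbb{R}^{m \times n}$, so $\mathbf{x} \in \mathbb{R}^n$ and $\mathbf{z} \in \mathbb{R}^m$. We
can compute the Jacobian wrt the input vector, $\mathbf{J} = \frac{\partial \mathbf{z}}{\partial \mathbf{x}} \in \mathbb{R}^{m \times n}$, as follows. Note that

$$z_i = \sum_{k=1}^{n} W_{ik} x_k \tag{13.47}$$

So the $(i, j)$ entry of the Jacobian will be

$$\frac{\partial z_i}{\partial x_j} = \frac{\partial}{\partial x_j} \sum_{k=1}^{n} W_{ik} x_k = \sum_{k=1}^{n} W_{ik} \frac{\partial}{\partial x_j} x_k = W_{ij} \tag{13.48}$$

since $\frac{\partial}{\partial x_j} x_k = \mathbb{I}(k = j)$. Hence the Jacobian wrt the input is

$$\mathbf{J} = \frac{\partial \mathbf{z}}{\partial \mathbf{x}} = \mathbf{W} \tag{13.49}$$

The VJP between $\mathbf{u}^\top \in \mathbb{R}^{1 \times m}$ and $\mathbf{J} \in \mathbb{R}^{m \times n}$ is

$$\mathbf{u}^\top \frac{\partial \mathbf{z}}{\partial \mathbf{x}} = \mathbf{u}^\top \mathbf{W} \in \mathbb{R}^{1 \times n} \tag{13.50}$$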


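Now consider the Jacobian wrt the weight matrix, $\mathbf{J} = \frac{\partial \mathbf{z}}{\partial \mathbf{W}}$. This can be represented as an $m \times (m \times n)$
matrix, which is complex to deal with. So instead, let us focus on taking the gradient wrt a single
weight, $W_{ij}$. This is easier to compute, since $\frac{\partial \mathbf{z}}{\partial W_{ij}}$ is a vector. To compute this, note that

\begin{align}
z_k &= \sum_{l=1}^{n} W_{kl} x_l \tag{13.51} \\
\frac{\partial z_k}{\partial W_{ij}} &= \sum_{l=1}^{n} x_l \frac{\partial}{\partial W_{ij}} W_{kl} = \sum_{l=1}^{n} x_l \, \mathbb{I}(i = k \text{ and } j = l) \tag{13.52}
\end{align}

Hence

$$\frac{\partial \mathbf{z}}{\partial W_{ij}} = \begin{pmatrix} 0 & \cdots & 0 & x_j & 0 & \cdots & 0 \end{pmatrix}^\top \tag{13.53}$$

where the non-zero entry occurs in location $i$. The VJP between $\mathbf{u}^\top \in \mathbb{R}^{1 \times m}$ and $\frac{\partial \mathbf{z}}{\partial \mathbf{W}} \in \mathbb{R}^{m \times (m \times n)}$
can be represented as a matrix of shape $1 \times (m \times n)$. Note that

$$\mathbf{u}^\top \frac{\partial \mathbf{z}}{\partial W_{ij}} = \sum_{k=1}^{m} u_k \frac{\partial z_k}{\partial W_{ij}} = u_i x_j \tag{13.54}$$

Therefore

$$\left[ \mathbf{u}^\top \frac{\partial \mathbf{z}}{\partial \mathbf{W}} \right]_{1,:} = \mathbf{u} \mathbf{x}^\top \in \mathbb{R}^{m \times n} \tag{13.55}$$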


13.3.3.4 Putting it all together 


For an exercise that puts this all together, see Exercise 13.1. 
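As a rough sketch of how the layer VJPs above combine in Algorithm 13.3, the code below runs backpropagation for the one-hidden-layer MLP of Equation (13.28), assuming a ReLU nonlinearity with derivative 0 at the origin. The function and variable names are illustrative assumptions, and this is not intended as a solution to Exercise 13.1.

```python
def matvec(W, x):
    """Matrix-vector product W x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def vecmat(u, W):
    """VJP for a linear layer wrt its input: u^T W (Equation 13.50)."""
    return [sum(u[i] * W[i][j] for i in range(len(W))) for j in range(len(W[0]))]

def outer(u, x):
    """VJP for a linear layer wrt its weights: u x^T (Equation 13.55)."""
    return [[ui * xj for xj in x] for ui in u]

def backprop_mlp(W1, W2, x, y):
    """Loss and gradients for L = 0.5 * ||y - W2 relu(W1 x)||^2."""
    # Forward pass, storing the intermediate activations.
    x2 = matvec(W1, x)
    x3 = [max(a, 0.0) for a in x2]
    x4 = matvec(W2, x3)
    loss = 0.5 * sum((a - b) ** 2 for a, b in zip(x4, y))
    # Backward pass, applying one VJP per layer in reverse order.
    u4 = [a - b for a, b in zip(x4, y)]             # dL/dx4
    g_W2 = outer(u4, x3)                            # dL/dW2
    u3 = vecmat(u4, W2)                             # dL/dx3
    u2 = [u * (1.0 if a > 0 else 0.0) for u, a in zip(u3, x2)]  # through ReLU
    g_W1 = outer(u2, x)                             # dL/dW1
    u1 = vecmat(u2, W1)                             # dL/dx
    return loss, u1, g_W1, g_W2

# Example values, chosen so that no pre-activation sits at the ReLU kink.
W1_ex = [[0.5, -1.0], [1.5, 0.2], [-0.3, 0.8]]
W2_ex = [[1.0, -0.5, 0.25]]
x_ex, y_ex = [1.0, 2.0], [1.0]
loss_ex, gx_ex, gW1_ex, gW2_ex = backprop_mlp(W1_ex, W2_ex, x_ex, y_ex)
print(loss_ex, gx_ex)
```

Note that the forward activations are stored because each VJP in the backward pass is evaluated at the input that its layer saw during the forward pass.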


13.3.4 Computation graphs 


MLPs are a simple kind of DNN in which each layer feeds directly into the next, forming a chain 
structure, as shown in Figure 13.10. However, modern DNNs can combine differentiable components in 
much more complex ways, to create a computation graph, analogous to how programmers combine 
elementary functions to make more complex ones. (Indeed, some have suggested that “deep learning” 
be called “differentiable programming”.) The only restriction is that the resulting computation 
graph corresponds to a directed acyclic graph (DAG), where each node is a differentiable function 
of all its inputs. 
For example, consider the function 


$$f(x_1, x_2) = x_2 e^{x_1} \sqrt{x_1 + x_2 e^{x_1}} \tag{13.56}$$
Figure 13.11: An example of a computation graph with 2 (scalar) inputs and 1 (scalar) output. From [Blo20].


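We can compute this using the DAG in Figure 13.11, with the following intermediate functions:

\begin{align}
x_3 &= f_3(x_1) = e^{x_1} \tag{13.57} \\
x_4 &= f_4(x_2, x_3) = x_2 x_3 \tag{13.58} \\
x_5 &= f_5(x_1, x_4) = x_1 + x_4 \tag{13.59} \\
x_6 &= f_6(x_5) = \sqrt{x_5} \tag{13.60} \\
x_7 &= f_7(x_4, x_6) = x_4 x_6 \tag{13.61}
\end{align}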
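Note that we have numbered the nodes in topological order (parents before children). During the
backward pass, since the graph is no longer a chain, we may need to sum gradients along multiple
paths. For example, since $x_4$ influences $x_5$ and $x_7$, we have

$$\frac{\partial \mathbf{o}}{\partial x_4} = \frac{\partial \mathbf{o}}{\partial x_5} \frac{\partial x_5}{\partial x_4} + \frac{\partial \mathbf{o}}{\partial x_7} \frac{\partial x_7}{\partial x_4} \tag{13.62}$$

We can avoid repeated computation by working in reverse topological order. For example,

\begin{align}
\frac{\partial \mathbf{o}}{\partial x_7} &= \frac{\partial x_7}{\partial x_7} = \mathbf{I}_m \tag{13.63} \\
\frac{\partial \mathbf{o}}{\partial x_6} &= \frac{\partial \mathbf{o}}{\partial x_7} \frac{\partial x_7}{\partial x_6} \tag{13.64} \\
\frac{\partial \mathbf{o}}{\partial x_5} &= \frac{\partial \mathbf{o}}{\partial x_6} \frac{\partial x_6}{\partial x_5} \tag{13.65} \\
\frac{\partial \mathbf{o}}{\partial x_4} &= \frac{\partial \mathbf{o}}{\partial x_5} \frac{\partial x_5}{\partial x_4} + \frac{\partial \mathbf{o}}{\partial x_7} \frac{\partial x_7}{\partial x_4} \tag{13.66}
\end{align}

In general, we use

$$\frac{\partial \mathbf{o}}{\partial x_j} = \sum_{k \in \mathrm{children}(j)} \frac{\partial \mathbf{o}}{\partial x_k} \frac{\partial x_k}{\partial x_j} \tag{13.67}$$

where the sum is over all children $k$ of node $j$, as shown in Figure 13.12. The $\frac{\partial \mathbf{o}}{\partial x_k}$ gradient vector
has already been computed for each child $k$; this quantity is called the adjoint. This gets multiplied
by the Jacobian $\frac{\partial x_k}{\partial x_j}$ of each child.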
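To make the adjoint recursion concrete, the following sketch evaluates the example function of Equation (13.56) and its gradient by hand, visiting the nodes in reverse topological order. The variable names `d1`, ..., `d7` (the adjoints) are assumptions for this illustration; an autodiff system would build the same computation automatically.

```python
import math

def f_and_grad(x1, x2):
    """Return f(x1, x2) = x2 e^{x1} sqrt(x1 + x2 e^{x1}) and (df/dx1, df/dx2)."""
    # Forward pass in topological order.
    x3 = math.exp(x1)
    x4 = x2 * x3
    x5 = x1 + x4
    x6 = math.sqrt(x5)
    x7 = x4 * x6
    # Backward pass in reverse topological order; d_k is the adjoint do/dx_k.
    d7 = 1.0
    d6 = d7 * x4                   # x7 = x4 * x6
    d5 = d6 * 0.5 / x6             # x6 = sqrt(x5)
    d4 = d5 * 1.0 + d7 * x6        # x4 feeds both x5 and x7, so we sum
    d3 = d4 * x2                   # x4 = x2 * x3
    d2 = d4 * x3
    d1 = d5 * 1.0 + d3 * x3        # x1 feeds both x5 and x3
    return x7, (d1, d2)

value, (g1, g2) = f_and_grad(0.5, 1.2)
print(value, g1, g2)
```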
Figure 13.13: Computation graph for an MLP with input $\mathbf{x}$, hidden layer $\mathbf{h}$, output $\mathbf{o}$, loss function $L = \ell(\mathbf{o}, y)$,
an $\ell_2$ regularizer $s$ on the weights, and total loss $J = L + s$. From Figure 4.7.1 of [Zha+20]. Used with kind
permission of Aston Zhang.


The computation graph can be computed ahead of time, by using an API to define a static graph. 
(This is how Tensorflow 1 worked.) Alternatively, the graph can be computed “just in time”, by 
tracing the execution of the function on an input argument. (This is how Tensorflow eager mode 
works, as well as JAX and PyTorch.) The latter approach makes it easier to work with a dynamic 
graph, whose shape can change depending on the values computed by the function. 

Figure 13.13 shows a computation graph corresponding to an MLP with one hidden layer with 
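weight decay. More precisely, the model computes the linear pre-activations $\mathbf{z} = \mathbf{W}^{(1)} \mathbf{x}$, the
hidden activations $\mathbf{h} = \varphi(\mathbf{z})$, the linear outputs $\mathbf{o} = \mathbf{W}^{(2)} \mathbf{h}$, the loss $L = \ell(\mathbf{o}, y)$, the regularizer
$s = \frac{\lambda}{2} \left( \|\mathbf{W}^{(1)}\|_F^2 + \|\mathbf{W}^{(2)}\|_F^2 \right)$, and the total loss $J = L + s$.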


13.4 Training neural networks 


In this section, we discuss how to fit DNNs to data. The standard approach is to use maximum 
likelihood estimation, by minimizing the NLL: 


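$$\mathcal{L}(\boldsymbol{\theta}) = -\log p(\mathcal{D} | \boldsymbol{\theta}) = -\sum_{n=1}^{N} \log p(\mathbf{y}_n | \mathbf{x}_n; \boldsymbol{\theta}) \tag{13.68}$$
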
It is also common to add a regularizer (such as the negative log prior), as we discuss in Section 13.5. 
In principle we can just use the backprop algorithm (Section 13.3) to compute the gradient of
this loss and pass it to an off-the-shelf optimizer, such as those discussed in Chapter 8. (The Adam 
optimizer of Section 8.4.6.3 is a popular choice, due to its ability to scale to large datasets (by 
virtue of being an SGD-type algorithm), and to converge fairly quickly (by virtue of using diagonal 
preconditioning and momentum).) However, in practice this may not work well. In this section, 
we discuss various problems that may arise, as well as some solutions. For more details on the 
practicalities of training DNNs, see various other books, such as [HG20; Zha+20; Gér19]. 

In addition to practical issues, there are important theoretical issues. In particular, we note that 
the DNN loss is not a convex objective, so in general we will not be able to find the global optimum. 
Nevertheless, SGD can often find surprisingly good solutions. The research into why this is the case is
still being conducted; see [Bah+20] for a recent review of some of this work. 


13.4.1 Tuning the learning rate 


It is important to tune the learning rate (step size), to ensure convergence to a good solution. We 
discuss this issue in Section 8.4.3. 


13.4.2 Vanishing and exploding gradients 


When training very deep models, the gradient tends to become either very small (this is called the 
vanishing gradient problem) or very large (this is called the exploding gradient problem), 
because the error signal is being passed through a series of layers which either amplify or diminish it 
[Hoc+01]. (Similar problems arise in RNNs on long sequences, as we explain in Section 15.2.6.) 

To explain the problem in more detail, consider the gradient of the loss wrt a node at layer $l$:

$$\frac{\partial \mathcal{L}}{\partial \mathbf{z}_l} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}_{l+1}} \frac{\partial \mathbf{z}_{l+1}}{\partial \mathbf{z}_l} = \mathbf{J}_l \mathbf{g}_{l+1} \tag{13.69}$$

where $\mathbf{J}_l = \frac{\partial \mathbf{z}_{l+1}}{\partial \mathbf{z}_l}$ is the Jacobian matrix, and $\mathbf{g}_{l+1} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}_{l+1}}$ is the gradient at the next layer. If $\mathbf{J}_l$ is
constant across layers, it is clear that the contribution of the gradient from the final layer, $\mathbf{g}_L$, to
layer $l$ will be $\mathbf{J}^{L-l} \mathbf{g}_L$. Thus the behavior of the system depends on the eigenvectors of $\mathbf{J}$.

Although $\mathbf{J}$ is a real-valued matrix, it is not (in general) symmetric, so its eigenvalues and
eigenvectors can be complex-valued, with the imaginary components corresponding to oscillatory
behavior. Let $\lambda$ be the spectral radius of $\mathbf{J}$, which is the maximum of the absolute values of the
eigenvalues. If this is greater than 1, the gradient can explode; if this is less than 1, the gradient can
vanish. (Similarly, the spectral radius of $\mathbf{W}$, connecting $\mathbf{z}_l$ to $\mathbf{z}_{l+1}$, determines the stability of the
dynamical system when run in forwards mode.)

The exploding gradient problem can be ameliorated by gradient clipping, in which we cap the
magnitude of the gradient if it becomes too large, i.e., we use

$$\mathbf{g}' = \min\left(1, \frac{c}{\|\mathbf{g}\|}\right) \mathbf{g} \tag{13.70}$$


This way, the norm of g’ can never exceed c, but the vector is always in the same direction as g. 
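The clipping rule of Equation (13.70) takes only a few lines. The sketch below is a minimal version that clips by the Euclidean norm; the zero-norm guard is an added assumption to avoid dividing by zero and is not part of the equation.

```python
import math

def clip_gradient(g, c):
    """Rescale g so that its Euclidean norm is at most c, keeping its direction."""
    norm = math.sqrt(sum(v * v for v in g))
    if norm == 0.0:
        return list(g)             # nothing to rescale
    scale = min(1.0, c / norm)
    return [scale * v for v in g]

clipped = clip_gradient([3.0, 4.0], c=1.0)      # norm 5 is scaled down to norm 1
unchanged = clip_gradient([0.3, 0.4], c=1.0)    # norm 0.5 is left as is
print(clipped, unchanged)
```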
However, the vanishing gradient problem is more difficult to solve. There are various solutions, 
such as the following: 
| Name | Definition | Range | Reference |
|---|---|---|---|
| Sigmoid | $\sigma(a) = \frac{1}{1 + e^{-a}}$ | $[0, 1]$ | |
| Hyperbolic tangent | $\tanh(a) = 2\sigma(2a) - 1$ | $[-1, 1]$ | |
| Softplus | $\sigma_+(a) = \log(1 + e^{a})$ | $[0, \infty]$ | [GBB11] |
| Rectified linear unit | $\text{ReLU}(a) = \max(a, 0)$ | $[0, \infty]$ | [GBB11; KSH12] |
| Leaky ReLU | $\max(a, 0) + \alpha \min(a, 0)$ | $[-\infty, \infty]$ | [MHN13] |
| Exponential linear unit | $\max(a, 0) + \min(\alpha(e^{a} - 1), 0)$ | $[-\infty, \infty]$ | [CUH16] |
| Swish | $a\sigma(a)$ | $[-\infty, \infty]$ | [RZL17] |
| GELU | $a\Phi(a)$ | $[-\infty, \infty]$ | [HG16] |

Table 13.4: List of some popular activation functions for neural networks.


Figure 13.14: (a) Some popular activation functions. (b) Plot of their gradients. Generated by activation_fun_deriv_jax.ipynb.


• Modify the activation functions at each layer to prevent the gradient from becoming too large or too small; see Section 13.4.3.

• Modify the architecture so that the updates are additive rather than multiplicative; see Section 13.4.4.

• Modify the architecture to standardize the activations at each layer, so that the distribution of activations over the dataset remains constant during training; see Section 14.2.4.1.

• Carefully choose the initial values of the parameters; see Section 13.4.5.


13.4.3 Non-saturating activation functions 


In Section 13.2.3, we mentioned that the sigmoid activation function saturates at 0 for large negative 
inputs, and at 1 for large positive inputs. It turns out that the gradient signal in these regimes is 0, 
preventing backpropagation from working. 
To see why the gradient vanishes, consider a layer which computes $\mathbf{z} = \sigma(\mathbf{W}\mathbf{x})$, where

$$\varphi(a) = \sigma(a) = \frac{1}{1 + \exp(-a)} \tag{13.71}$$


If the weights are initialized to be large (positive or negative), then it becomes very easy for $\mathbf{a} = \mathbf{W}\mathbf{x}$
to take on large values, and hence for $\mathbf{z}$ to saturate near 0 or 1, since the sigmoid saturates, as shown
in Figure 13.14a. Now let us consider the gradient of the loss wrt the inputs $\mathbf{x}$ (from an earlier layer)
and the parameters $\mathbf{W}$. The derivative of the activation function is given by

$$\varphi'(a) = \sigma(a)(1 - \sigma(a)) \tag{13.72}$$

See Figure 13.14b for a plot. In Section 13.3.3, we show that the gradient of the loss wrt the inputs
is

$$\frac{\partial \mathcal{L}}{\partial \mathbf{x}} = \mathbf{W}^\top \boldsymbol{\delta} = \mathbf{W}^\top \mathbf{z}(1 - \mathbf{z}) \tag{13.73}$$

and the gradient of the loss wrt the parameters is

$$\frac{\partial \mathcal{L}}{\partial \mathbf{W}} = \boldsymbol{\delta} \mathbf{x}^\top = \mathbf{z}(1 - \mathbf{z}) \mathbf{x}^\top \tag{13.74}$$

Hence, if $\mathbf{z}$ is near 0 or 1, the gradients will go to 0.
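A quick numerical check of Equation (13.72) shows how fast the sigmoid derivative decays away from the origin. This is a minimal sketch; the chosen input values are arbitrary.

```python
import math

def sigmoid(a):
    """Logistic sigmoid, Equation (13.71)."""
    return 1.0 / (1.0 + math.exp(-a))

def sigmoid_grad(a):
    """Derivative sigma(a) * (1 - sigma(a)), Equation (13.72)."""
    z = sigmoid(a)
    return z * (1.0 - z)

# The derivative peaks at a = 0 and shrinks rapidly as |a| grows.
for a in [0.0, 2.0, 5.0, 10.0]:
    print(a, sigmoid_grad(a))
```

Since the derivative is at most 0.25 (attained at $a = 0$), each sigmoid layer scales the backpropagated signal by at most that factor, even before saturation sets in.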

One of the keys to being able to train very deep models is to use non-saturating activation 
functions. Several different functions have been proposed: see Table 13.4 for a summary, and 
https://mlfromscratch.com/activation-functions-explained for more details. 


13.4.3.1 ReLU 


The most common is rectified linear unit or ReLU, proposed in [GBB11; KSH12]. This is defined 
as 


$$\text{ReLU}(a) = \max(a, 0) = a \, \mathbb{I}(a > 0) \tag{13.75}$$

The ReLU function simply “turns off” negative inputs, and passes positive inputs unchanged. The
gradient has the following form:

$$\text{ReLU}'(a) = \mathbb{I}(a > 0) \tag{13.76}$$


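Now suppose we use this in a layer to compute $\mathbf{z} = \text{ReLU}(\mathbf{W}\mathbf{x})$. In Section 13.3.3, we show that the
gradient wrt the inputs has the form

$$\frac{\partial \mathcal{L}}{\partial \mathbf{x}} = \mathbf{W}^\top \mathbb{I}(\mathbf{z} > \mathbf{0}) \tag{13.77}$$

and wrt the parameters has the form

$$\frac{\partial \mathcal{L}}{\partial \mathbf{W}} = \mathbb{I}(\mathbf{z} > \mathbf{0}) \mathbf{x}^\top \tag{13.78}$$

Hence the gradient will not vanish, as long as $\mathbf{z}$ is positive.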

Unfortunately, if the weights are initialized to be large and negative, then it becomes very easy for 
(some components of) $\mathbf{a} = \mathbf{W}\mathbf{x}$ to take on large negative values, and hence for $\mathbf{z}$ to go to 0. This
will cause the gradient for the weights to go to 0. The algorithm will never be able to escape this 
situation, so the hidden units (components of z) will stay permanently off. This is called the “dead 
ReLU” problem [Lu+19]. 


13.4.3.2 Non-saturating ReLU 


The problem of dead ReLUs can be solved by using non-saturating variants of ReLU. One alternative
is the leaky ReLU, proposed in [MHN13]. This is defined as

$$\text{LReLU}(a; \alpha) = \max(\alpha a, a) \tag{13.79}$$

where $0 < \alpha < 1$. The slope of this function is 1 for positive inputs, and $\alpha$ for negative inputs, thus
ensuring there is some signal passed back to earlier layers, even when the input is negative. See
Figure 13.14b for a plot. If we allow the parameter $\alpha$ to be learned, rather than fixed, the leaky
ReLU is called parametric ReLU [He+15].

Another popular choice is the ELU, proposed in [CUH16]. This is defined by 


$$\text{ELU}(a; \alpha) = \begin{cases} \alpha(e^{a} - 1) & \text{if } a \le 0 \\ a & \text{if } a > 0 \end{cases} \tag{13.80}$$

This has the advantage over leaky ReLU of being a smooth function.⁶ See Figure 13.14 for a plot.
A slight variant of ELU, known as SELU (self-normalizing ELU), was proposed in [Kla+17]. This 
has the form 


$$\text{SELU}(a; \alpha, \lambda) = \lambda \, \text{ELU}(a; \alpha) \tag{13.81}$$

Surprisingly, they prove that by setting $\alpha$ and $\lambda$ to carefully chosen values, this activation function
is guaranteed to ensure that the output of each layer is standardized (provided the input is also 
standardized), even without the use of techniques such as batchnorm (Section 14.2.4.1). This can 
help with model fitting. 
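The ReLU variants above are simple to write down. The sketch below implements Equations (13.79) to (13.81) and runs a small Monte Carlo check of the self-normalizing property. The SELU constants are the values commonly quoted from [Kla+17]; the sample size, seed, and default `alpha` values are arbitrary choices made for this illustration.

```python
import math
import random

def leaky_relu(a, alpha=0.01):
    """Equation (13.79); assumes 0 < alpha < 1."""
    return max(alpha * a, a)

def elu(a, alpha=1.0):
    """Equation (13.80)."""
    return a if a > 0 else alpha * (math.exp(a) - 1.0)

# Constants commonly quoted from [Kla+17] for SELU.
SELU_ALPHA = 1.6732632423543772
SELU_LAMBDA = 1.0507009873554805

def selu(a):
    """Equation (13.81) with the constants above."""
    return SELU_LAMBDA * elu(a, SELU_ALPHA)

# Feed standardized Gaussian inputs through SELU and measure the output moments.
rng = random.Random(0)
outputs = [selu(rng.gauss(0.0, 1.0)) for _ in range(50000)]
selu_mean = sum(outputs) / len(outputs)
selu_var = sum((o - selu_mean) ** 2 for o in outputs) / len(outputs)
print(selu_mean, selu_var)
```

With standardized inputs, the output mean and variance should come out close to 0 and 1 respectively, which is the fixed point that the constants are designed to preserve.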


13.4.3.3 Other choices 


As an alternative to manually discovering good activation functions, we can use blackbox optimization 
methods to search over the space of functional forms. Such an approach was used in [RZL17], where 
they discovered a function they call swish that seems to do well on some image classification 
benchmarks. It is defined by 


$$\text{swish}(a; \beta) = a \sigma(\beta a) \tag{13.82}$$

(The same function, under the name SiLU (for Sigmoid Linear Unit), was independently proposed
in [HG16].) See Figure 13.14 for a plot.


6. ELU only has a continuous first derivative if $\alpha = 1$.
Figure 13.15: (a) Illustration of a residual block. (b) Illustration of why adding residual connections can help
when training a very deep model. Adapted from Figure 14.16 of [Gér19]. 


Another popular activation function is GELU, which stands for “Gaussian Error Linear Unit” 
[HG16]. This is defined as follows: 


$$\text{GELU}(a) = a \Phi(a) \tag{13.83}$$

where $\Phi(a)$ is the cdf of a standard normal:

$$\Phi(a) = \Pr(\mathcal{N}(0, 1) \le a) = \frac{1}{2} \left( 1 + \text{erf}(a / \sqrt{2}) \right) \tag{13.84}$$

We see from Figure 13.14 that this is not a convex or monotonic function, unlike most other
activation functions.

We can think of GELU as a “soft” version of ReLU, since it replaces the step function $\mathbb{I}(a > 0)$
with the Gaussian cdf, $\Phi(a)$. Alternatively, the GELU can be motivated as an adaptive version of
dropout (Section 13.5.4), where we multiply the input by a binary scalar mask, $m \sim \text{Ber}(\Phi(a))$,
where the probability of being dropped is given by $1 - \Phi(a)$. Thus the expected output is

$$\mathbb{E}[a] = \Phi(a) \times a + (1 - \Phi(a)) \times 0 = a \Phi(a) \tag{13.85}$$

We can approximate GELU using swish with a particular parameter setting, namely

$$\text{GELU}(a) \approx a \sigma(1.702 a) \tag{13.86}$$
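The quality of the approximation in Equation (13.86) is easy to measure. The sketch below compares the exact GELU with the swish approximation on a grid; the grid range and spacing are arbitrary choices made for this illustration.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def std_normal_cdf(a):
    """Phi(a) from Equation (13.84)."""
    return 0.5 * (1.0 + math.erf(a / math.sqrt(2.0)))

def gelu(a):
    """Exact GELU, Equation (13.83)."""
    return a * std_normal_cdf(a)

def swish(a, beta=1.0):
    """Swish, Equation (13.82)."""
    return a * sigmoid(beta * a)

# Largest absolute gap between GELU and swish(beta=1.702) on [-6, 6].
grid = [i / 100.0 for i in range(-600, 601)]
max_gap = max(abs(gelu(a) - swish(a, 1.702)) for a in grid)
print(max_gap)
```

The same helpers also make the non-monotonicity visible: `gelu(-1.0)` is smaller than both `gelu(-3.0)` and `gelu(0.0)`.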


13.4.4 Residual connections 


One solution to the vanishing gradient problem for DNNs is to use a residual network or ResNet 
[He+16a]. This is a feedforward model in which each layer has the form of a residual block, defined 
by 


$$\mathcal{F}_l'(\mathbf{x}) = \mathcal{F}_l(\mathbf{x}) + \mathbf{x} \tag{13.87}$$

where $\mathcal{F}_l$ is a standard shallow nonlinear mapping (e.g., linear-activation-linear). The inner $\mathcal{F}_l$
function computes the residual term or delta that needs to be added to the input $\mathbf{x}$ to generate
the desired output; it is often easier to learn to generate a small perturbation to the input than to 
directly predict the output. (Residual connections are usually used in conjunction with CNNs, as 
discussed in Section 14.3.4, but can also be used in MLPs.) 

A model with residual connections has the same number of parameters as a model without residual 
connections, but it is easier to train. The reason is that gradients can flow directly from the output 
to earlier layers, as sketched in Figure 13.15b. To see this, note that the activations at the output 
layer can be derived in terms of any previous layer l using 


$$\mathbf{z}_L = \mathbf{z}_l + \sum_{i=l}^{L-1} \mathcal{F}_i(\mathbf{z}_i; \boldsymbol{\theta}_i) \tag{13.88}$$


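We can therefore compute the gradient of the loss wrt the parameters of the $l$'th layer as follows:

\begin{align}
\frac{\partial \mathcal{L}}{\partial \boldsymbol{\theta}_l} &= \frac{\partial \mathbf{z}_l}{\partial \boldsymbol{\theta}_l} \frac{\partial \mathcal{L}}{\partial \mathbf{z}_l} \tag{13.89} \\
&= \frac{\partial \mathbf{z}_l}{\partial \boldsymbol{\theta}_l} \frac{\partial \mathcal{L}}{\partial \mathbf{z}_L} \frac{\partial \mathbf{z}_L}{\partial \mathbf{z}_l} \tag{13.90} \\
&= \frac{\partial \mathbf{z}_l}{\partial \boldsymbol{\theta}_l} \frac{\partial \mathcal{L}}{\partial \mathbf{z}_L} \left( 1 + \sum_{i=l}^{L-1} \frac{\partial \mathcal{F}_i(\mathbf{z}_i; \boldsymbol{\theta}_i)}{\partial \mathbf{z}_l} \right) \tag{13.91} \\
&= \frac{\partial \mathbf{z}_l}{\partial \boldsymbol{\theta}_l} \frac{\partial \mathcal{L}}{\partial \mathbf{z}_L} + \text{other terms} \tag{13.92}
\end{align}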


Thus we see that the gradient at layer l depends directly on the gradient at layer L in a way that is 
independent of the depth of the network. 
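The effect can be seen in a toy scalar example. The sketch below assumes a hypothetical one-dimensional "layer" $\mathcal{F}(z) = w \tanh(z)$ shared across depth, and compares the derivative of the final activation wrt the first one with and without the skip connection. It is only meant to illustrate the multiplicative versus additive behavior, not a real ResNet.

```python
import math

def chain_gradient(z, w, depth, residual):
    """d z_L / d z_1 for z_{l+1} = F(z_l) (+ z_l if residual), with F(z) = w * tanh(z)."""
    grad = 1.0
    for _ in range(depth):
        F = w * math.tanh(z)
        dF = w * (1.0 - math.tanh(z) ** 2)
        if residual:
            z, grad = z + F, grad * (1.0 + dF)   # identity path adds 1 to each factor
        else:
            z, grad = F, grad * dF               # purely multiplicative chain
    return grad

plain_grad = chain_gradient(0.5, 0.5, depth=50, residual=False)
resnet_grad = chain_gradient(0.5, 0.5, depth=50, residual=True)
print(plain_grad, resnet_grad)
```

In the plain chain every factor is at most $w = 0.5$, so the product shrinks geometrically with depth; with the skip connection every factor is at least 1, so the signal from the output is preserved.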


13.4.5 Parameter initialization 


Since the objective function for DNN training is non-convex, the way that we initialize the parameters 
of a DNN can play a big role in what kind of solution we end up with, as well as how easy the
function is to train (i.e., how well information can flow forwards and backwards through the model).
In the rest of this section, we present some common heuristic methods that are used for initializing 
parameters. 


13.4.5.1 Heuristic initialization schemes 


In [GB10], they show that sampling parameters from a standard normal with fixed variance can 
result in exploding activations or gradients. To see why, consider a linear unit with no activation 


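function given by $o_i = \sum_{j=1}^{n_{\text{in}}} w_{ij} x_j$; suppose $w_{ij} \sim \mathcal{N}(0, \sigma^2)$, and $\mathbb{E}[x_j] = 0$ and $\mathbb{V}[x_j] = \gamma^2$, where
we assume $x_j$ are independent of $w_{ij}$. The mean and variance of the output is given by

\begin{align}
\mathbb{E}[o_i] &= \sum_{j=1}^{n_{\text{in}}} \mathbb{E}[w_{ij} x_j] = \sum_{j=1}^{n_{\text{in}}} \mathbb{E}[w_{ij}] \, \mathbb{E}[x_j] = 0 \tag{13.93} \\
\mathbb{V}[o_i] &= \mathbb{E}[o_i^2] - (\mathbb{E}[o_i])^2 = \sum_{j=1}^{n_{\text{in}}} \mathbb{E}[w_{ij}^2 x_j^2] - 0 = \sum_{j=1}^{n_{\text{in}}} \mathbb{E}[w_{ij}^2] \, \mathbb{E}[x_j^2] = n_{\text{in}} \sigma^2 \gamma^2 \tag{13.94}
\end{align}

To keep the output variance from blowing up, we need to ensure $n_{\text{in}} \sigma^2 = 1$ (or some other constant),
where $n_{\text{in}}$ is the fan-in of a unit (number of incoming connections).

Now consider the backwards pass. By analogous reasoning, we see that the variance of the gradients
can blow up unless $n_{\text{out}} \sigma^2 = 1$, where $n_{\text{out}}$ is the fan-out of a unit (number of outgoing connections).
To satisfy both requirements at once, we set $\frac{1}{2}(n_{\text{in}} + n_{\text{out}}) \sigma^2 = 1$, or equivalently

$$\sigma^2 = \frac{2}{n_{\text{in}} + n_{\text{out}}} \tag{13.95}$$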


This is known as Xavier initialization or Glorot initialization, named after the first author of 
[GB10]. 

A special case arises if we use $\sigma^2 = 1/n_{\text{in}}$; this is known as LeCun initialization, named after
Yann LeCun, who proposed it in the 1990s. This is equivalent to Glorot initialization when $n_{\text{in}} = n_{\text{out}}$.
If we use $\sigma^2 = 2/n_{\text{in}}$, the method is called He initialization, named after Kaiming He, who proposed
it in [He+15].

Note that it is not necessary to use a Gaussian distribution. Indeed, the above derivation just
worked in terms of the first two moments (mean and variance), and made no assumptions about
Gaussianity. For example, suppose we sample weights from a uniform distribution, $w_{ij} \sim \text{Unif}(-a, a)$.
The mean is 0, and the variance is $\sigma^2 = a^2/3$. Hence we should set $a = \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}$.

Although the above derivation assumes a linear output unit, the technique works well empirically
even for nonlinear units. The best choice of initialization method depends on which activation
function you use. For linear, tanh, logistic, and softmax, Glorot is recommended. For ReLU and
variants, He is recommended. For SELU, LeCun is recommended. See e.g., [Gér19] for more heuristics,
and e.g., [HDR19] for some theory.
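The variance calculation in Equation (13.94) can be checked by simulation. The sketch below is a minimal illustration with hypothetical helper names: it samples a single linear unit many times with unit-variance inputs ($\gamma^2 = 1$) and compares a fixed-variance initializer against LeCun scaling. The fan-in, trial count, and seed are arbitrary choices.

```python
import math
import random

def init_std(n_in, n_out, scheme):
    """Standard deviation of the weights for the schemes discussed above."""
    if scheme == "glorot":
        return math.sqrt(2.0 / (n_in + n_out))
    if scheme == "lecun":
        return math.sqrt(1.0 / n_in)
    if scheme == "he":
        return math.sqrt(2.0 / n_in)
    raise ValueError("unknown scheme: " + scheme)

def uniform_limit(std):
    """Half-width a of Unif(-a, a) whose variance a^2 / 3 equals std^2."""
    return math.sqrt(3.0) * std

def output_variance(n_in, std, n_trials=2000, seed=0):
    """Empirical variance of o = sum_j w_j x_j with w_j ~ N(0, std^2), x_j ~ N(0, 1)."""
    rng = random.Random(seed)
    outs = [sum(rng.gauss(0.0, std) * rng.gauss(0.0, 1.0) for _ in range(n_in))
            for _ in range(n_trials)]
    mean = sum(outs) / n_trials
    return sum((o - mean) ** 2 for o in outs) / n_trials

n_in = 64
var_naive = output_variance(n_in, std=1.0)                           # grows like n_in
var_lecun = output_variance(n_in, std=init_std(n_in, n_in, "lecun"))  # stays near 1
print(var_naive, var_lecun)
```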


13.4.5.2 Data-driven initializations 


We can also adopt a data-driven approach to parameter initialization. For example, [MM16] proposed 
a simple but effective scheme known as layer-sequential unit-variance (LSUV) initialization, 
which works as follows. First we initialize the weights of each (fully connected or convolutional) 
layer with orthonormal matrices, as proposed in [SMG14]. (This can be achieved by drawing from 
$\mathbf{w} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, reshaping $\mathbf{w}$ into a matrix $\mathbf{W}$, and then computing an orthonormal basis using QR or
SVD decomposition.) Then, for each layer $l$, we compute the variance $v_l$ of the activations across a
minibatch; we then rescale using $\mathbf{W}_l := \mathbf{W}_l / \sqrt{v_l}$. This scheme can be viewed as an orthonormal
initialization combined with batch normalization performed only on the first mini-batch. This is 
faster than full batch normalization, but can sometimes work just as well. 
Figure 13.16: Calculation of minibatch stochastic gradient using data parallelism and two GPUs. From Figure
12.5.2 of [Zha+20]. Used with kind permission of Aston Zhang. 


13.4.6 Parallel training 


It can be quite slow to train large models on large datasets. One way to speed this process up is 
to use specialized hardware, such as graphics processing units (GPUs) and tensor processing 
units (TPUs), which are very efficient at performing matrix-matrix multiplication. If we have 
multiple GPUs, we can sometimes further speed things up. There are two main approaches: model 
parallelism, in which we partition the model between machines, and data parallelism, in which 
each machine has its own copy of the model, and applies it to a different set of data. 

Model parallelism can be quite complicated, since it requires tight communication between machines 
to ensure they compute the correct answer. We will not discuss this further. Data parallelism is 
generally much simpler, since it is embarrassingly parallel. To use this to speed up training, at each
training step $t$, we do the following: 1) we partition the minibatch across the $K$ machines to get $\mathcal{D}_t^k$;
2) each machine $k$ computes its own gradient, $\mathbf{g}_t^k = \nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}; \mathcal{D}_t^k)$; 3) we collect all the local gradients
on a central machine (e.g., device 0) and sum them using $\mathbf{g}_t = \sum_{k=1}^{K} \mathbf{g}_t^k$; 4) we broadcast the summed
gradient back to all devices, so $\tilde{\mathbf{g}}_t^k = \mathbf{g}_t$; 5) each machine updates its own copy of the parameters
using $\boldsymbol{\theta}_t^k := \boldsymbol{\theta}_t^k - \eta_t \tilde{\mathbf{g}}_t^k$. See Figure 13.16 for an illustration and multi_gpu_training_jax.ipynb for
some sample code.

Note that steps 3 and 4 are usually combined into one atomic step; this is known as an all-reduce 
operation (where we use sum to reduce the set of (gradient) vectors into one). If each machine 
blocks until receiving the centrally aggregated gradient, $\mathbf{g}_t$, the method is known as synchronous
training. This will give the same results as training with one machine (with a larger batch size),
only faster (assuming we ignore any batch normalization layers). If we let each machine update 
its parameters using its own local gradient estimate, and not wait for the broadcast to/from the 
other machines, the method is called asynchronous training. This is not guaranteed to work, 
since the different machines may get out of step, and hence will be updating different versions of the 
parameters; this approach has therefore been called hogwild training [Niu+11]. However, if the 
updates are sparse, so each machine “touches” a different part of the parameter vector, one can prove 
that hogwild training behaves like standard synchronous SGD. 
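The five steps of synchronous data parallelism can be simulated on a single machine. The sketch below is a toy illustration, assuming a linear model with a summed squared-error loss and plain Python lists in place of devices; the function names are hypothetical and no real communication library is involved.

```python
def grad_squared_error(theta, batch):
    """Gradient of sum_n 0.5 * (theta^T x_n - y_n)^2 wrt theta."""
    g = [0.0] * len(theta)
    for x, y in batch:
        err = sum(t * xi for t, xi in zip(theta, x)) - y
        for j, xj in enumerate(x):
            g[j] += err * xj
    return g

def all_reduce_sum(local_grads):
    """Steps 3 and 4: sum the local gradients and hand every replica a copy."""
    total = [sum(col) for col in zip(*local_grads)]
    return [list(total) for _ in local_grads]

def data_parallel_step(thetas, minibatch, lr):
    """One synchronous step over K = len(thetas) simulated machines."""
    K = len(thetas)
    shards = [minibatch[k::K] for k in range(K)]                            # step 1
    local = [grad_squared_error(thetas[k], shards[k]) for k in range(K)]    # step 2
    reduced = all_reduce_sum(local)                                         # steps 3-4
    return [[t - lr * g for t, g in zip(thetas[k], reduced[k])]             # step 5
            for k in range(K)]

batch = [([1.0, 2.0], 1.0), ([0.5, -1.0], 0.0), ([2.0, 0.0], -1.0), ([-1.0, 1.0], 2.0)]
theta0 = [0.1, -0.2]
replicas = data_parallel_step([list(theta0), list(theta0)], batch, lr=0.1)
print(replicas)
```

Because the loss is a sum over examples, the summed shard gradients equal the gradient on the full minibatch, so every replica ends up with the same parameters as a single machine would.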
13.5 Regularization


In Section 13.4 we discussed computational issues associated with training (large) neural networks. 
In this section, we discuss statistical issues. In particular, we focus on ways to avoid overfitting. This 
is crucial, since large neural networks can easily have millions of parameters. 


13.5.1 Early stopping 


Perhaps the simplest way to prevent overfitting is called early stopping, which refers to the 
heuristic of stopping the training procedure when the error on the validation set starts to increase 
(see Figure 4.8 for an example). This method works because we are restricting the ability of the 
optimization algorithm to transfer information from the training examples to the parameters, as 
explained in [AS19]. 


13.5.2 Weight decay 


A common approach to reduce overfitting is to impose a prior on the parameters, and then use 
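MAP estimation. It is standard to use a Gaussian prior for the weights $\mathcal{N}(\mathbf{w} | \mathbf{0}, \alpha^2 \mathbf{I})$ and biases,
$\mathcal{N}(\mathbf{b} | \mathbf{0}, \beta^2 \mathbf{I})$. This is equivalent to $\ell_2$ regularization of the objective. In the neural networks literature,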
this is called weight decay, since it encourages small weights, and hence simpler models, as in ridge 
regression (Section 11.3). 


13.5.3 Sparse DNNs 


Since there are many weights in a neural network, it is often helpful to encourage sparsity. This 
allows us to perform model compression, which can save memory and time. To do this, we can 
use $\ell_1$ regularization (as in Section 11.4), or ARD (as in Section 11.7.7), or several other methods
(see e.g., [Hoe+21; Bha+20] for recent reviews). As a simple example, Figure 13.17 shows a 5 layer
MLP which has been fit to some 1d regression data using an $\ell_1$ regularizer on the weights. We see
that the resulting graph topology is sparse. 

Despite the intuitive appeal of sparse topology, in practice these methods are not widely used, 
since modern GPUs are optimized for dense matrix multiplication, and there are few computational 
benefits to sparse weight matrices. However, if we use methods that encourage group sparsity, we can 
prune out whole layers of the model. This results in block sparse weight matrices, which can result in 
speedups and memory savings (see e.g., [Sca+17; Wen+16; MAV17; LUW17]). 


13.5.4 Dropout 


Suppose that we randomly (on a per-example basis) turn off all the outgoing connections from 
each neuron with probability p, as illustrated in Figure 13.18. This technique is known as dropout 
[Sri+14]. 

Dropout can dramatically reduce overfitting and is very widely used. Intuitively, the reason dropout 
works well is that it prevents complex co-adaptation of the hidden units. In other words, each unit 
must learn to perform well even if some of the other units are missing at random. This prevents the 
Figure 13.17: (a) A deep but sparse neural network. The connections are pruned using $\ell_1$ regularization. At
each level, nodes numbered 0 are clamped to 1, so their outgoing weights correspond to the offset/bias terms.
(b) Predictions made by the model on the training set. Generated by sparse_mlp.ipynb.
Figure 13.18: Illustration of dropout. (a) A standard neural net with 2 hidden layers. (b) An example of a
thinned net produced by applying dropout with po = 0.5. Units that have been dropped out are marked with an 
x. From Figure 1 of [Sri+14]. Used with kind permission of Geoff Hinton. 


units from learning complex, but fragile, dependencies on each other.⁷ A more formal explanation,
in terms of Gaussian scale mixture priors, can be found in [NHLS19]. 

We can view dropout as estimating a noisy version of the weights, $\theta_{lij} = w_{lij} \epsilon_{li}$, where $\epsilon_{li} \sim \text{Ber}(1 - p)$
is a Bernoulli noise term. (So if we sample $\epsilon_{li} = 0$, then all of the weights going out of
unit $i$ in layer $l - 1$ into any $j$ in layer $l$ will be set to 0.) At test time, we usually turn the noise off.


7. Geoff Hinton, who invented dropout, said he was inspired by a talk on sexual reproduction, which encourages genes 
to be individually useful (or at most depend on a small number of other genes), even when combined with random 
other genes. 
To ensure the weights have the same expectation at test time as they did during training (so the
input activation to the neurons is the same, on average), at test time we should use $w_{lij} = \theta_{lij} \mathbb{E}[\epsilon_{li}]$.
For Bernoulli noise, we have $\mathbb{E}[\epsilon] = 1 - p$, so we should multiply the weights by the keep probability,
$1 - p$, before making predictions.

We can, however, use dropout at test time if we wish. The result is an ensemble of networks, 
each with slightly different sparse graph structures. This is called Monte Carlo dropout [GG16; 
KG17], and has the form 


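$$p(\mathbf{y} | \mathbf{x}, \mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^{S} p(\mathbf{y} | \mathbf{x}, \hat{\mathbf{W}} \boldsymbol{\epsilon}^s + \hat{\mathbf{b}}) \tag{13.96}$$

where $S$ is the number of samples, and we write $\hat{\mathbf{W}} \boldsymbol{\epsilon}^s$ to indicate that we are multiplying all
the estimated weight matrices by a sampled noise vector. This can sometimes provide a good
approximation to the Bayesian posterior predictive distribution $p(\mathbf{y} | \mathbf{x}, \mathcal{D})$, especially if the noise rate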
is optimized [GHK17]. 


13.5.5 Bayesian neural networks 


Modern DNNs are usually trained using a (penalized) maximum likelihood objective to find a single 
setting of parameters. However, with large models, there are often many more parameters than data 
points, so there may be multiple possible models which fit the training data equally well, yet which 
generalize in different ways. It is often useful to capture the induced uncertainty in the posterior 
predictive distribution. This can be done by marginalizing out the parameters by computing 


$$p(y|x, \mathcal{D}) = \int p(y|x, \theta)\, p(\theta|\mathcal{D})\, d\theta \tag{13.97}$$


The result is known as a Bayesian neural network or BNN. It can be thought of as an infinite 
ensemble of differently weighted neural networks. By marginalizing out the parameters, we can avoid 
overfitting [Mac95]. Bayesian marginalization is challenging for large neural networks, but it can also 
lead to significant performance gains [WI20]. For more details on the topic of Bayesian deep 
learning, see the sequel to this book, [Mur23]. 


13.5.6 Regularization effects of (stochastic) gradient descent * 


Some optimization methods (in particular, second-order batch methods) are able to find “needles 
in haystacks”, that is, narrow but deep “holes” in the loss landscape whose parameter settings have 
very low loss. These are known as sharp minima; see Figure 13.19(right). 
From the point of view of minimizing the empirical loss, the optimizer has done a good job. However, 
such solutions generally correspond to a model that has overfit the data. It is better to find points 
that correspond to flat minima, as shown in Figure 13.19(left); such solutions are more robust and 
generalize better. To see why, note that flat minima correspond to regions in parameter space where 
there is a lot of posterior uncertainty, and hence samples from this region are less able to precisely 
memorize irrelevant details about the training set [AS17]. SGD often finds such flat minima by 
virtue of the addition of noise, which prevents it from “entering” narrow regions of the loss landscape 
(see e.g., [SL18]). This is called implicit regularization. It is also possible to explicitly encourage 
Figure 13.19: Flat vs sharp minima. From Figures 1 and 2 of [HS97a]. Used with kind permission of Jürgen 
Schmidhuber. 
Figure 13.20: Each curve shows how the loss varies across parameter values for a given minibatch. (a) A 
stable local minimum. (b) An unstable local minimum. Generated by sgd_minima_variance.ipynb. Adapted 
from https://bit.ly/3wTc1L6. 


SGD to find such flat minima, using entropy SGD [Cha+17], sharpness aware minimization 
[For+21], stochastic weight averaging (SWA) [Izm+18], and other related techniques. 

Of course, the loss landscape depends not just on the parameter values, but also on the data. Since 
we usually cannot afford to do full-batch gradient descent, we will get a set of loss curves, one per 
minibatch. If each one of these curves corresponds to a wide basin, as shown in Figure 13.20a, we 
are at a point in parameter space that is robust to perturbations, and will likely generalize well. 
However, if the overall wide basin is the result of averaging over many different narrow basins, as 
shown in Figure 13.20b, the resulting estimate will likely generalize less well. 

This can be formalized using the analysis in [Smi+21; BD21]. Specifically, they consider continuous 
time gradient flow which approximates the behavior of (S)GD. In [BD21], they consider full-batch 
GD, and show that the flow has the form $\dot{w} = -\nabla_w \tilde{\mathcal{L}}_{GD}(w)$, where 

$$\tilde{\mathcal{L}}_{GD}(w) = \mathcal{L}(w) + \frac{\epsilon}{4}\|\nabla \mathcal{L}(w)\|^2 \tag{13.98}$$

where $\mathcal{L}(w)$ is the original loss, $\epsilon$ is the learning rate, and the second term is an implicit regularization 
term that penalizes solutions with large gradients (high curvature). 
In [Smi+21], they extend this analysis to the SGD case. They show that the flow has the form 
$\dot{w} = -\nabla_w \tilde{\mathcal{L}}_{SGD}(w)$, where 

$$\tilde{\mathcal{L}}_{SGD}(w) = \mathcal{L}(w) + \frac{\epsilon}{4m}\sum_{k=1}^{m} \|\nabla \mathcal{L}_k(w)\|^2 \tag{13.99}$$

where $m$ is the number of minibatches, and $\mathcal{L}_k(w)$ is the loss on the k’th such minibatch. Comparing 
this to the full-batch GD loss, we see 

$$\tilde{\mathcal{L}}_{SGD}(w) = \tilde{\mathcal{L}}_{GD}(w) + \frac{\epsilon}{4m}\sum_{k=1}^{m} \|\nabla \mathcal{L}_k(w) - \nabla \mathcal{L}(w)\|^2 \tag{13.100}$$


The second term estimates the variance of the minibatch gradients, which is a measure of stability, 
and hence of generalization ability. 
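
The step from Equation (13.99) to Equation (13.100) is the usual decomposition of a mean of squared norms into a squared mean plus a variance. The sketch below checks this numerically. It assumes the full-batch gradient is the average of the minibatch gradients (equal-sized minibatches), and it uses random vectors as stand-ins for the minibatch gradients; the array `G` is purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
m, dim = 8, 5                   # number of minibatches, number of parameters
eps = 0.1                       # learning rate
G = rng.normal(size=(m, dim))   # row k stands in for the gradient of L_k at w
g = G.mean(axis=0)              # full-batch gradient (assumed to be the mean)

sgd_penalty = eps / (4 * m) * np.sum(G ** 2)          # extra term in (13.99)
gd_penalty = eps / 4 * np.sum(g ** 2)                 # extra term in (13.98)
variance_term = eps / (4 * m) * np.sum((G - g) ** 2)  # extra term in (13.100)

print(sgd_penalty, gd_penalty + variance_term)
```

The two printed numbers should agree up to floating point error, so the SGD penalty exceeds the GD penalty by exactly the (scaled) variance of the minibatch gradients.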

The above analysis shows that SGD not only has computational advantages (since it is faster than 
full-batch GD or second-order methods), but also statistical advantages. 


13.6 Other kinds of feedforward networks * 


13.6.1 Radial basis function networks 


Consider a 1 layer neural net where the hidden layer is given by the feature vector 


$$\phi(x) = [K(x, \mu_1), \ldots, K(x, \mu_K)] \tag{13.101}$$

where $\mu_k \in \mathcal{X}$ are a set of $K$ centroids or exemplars, and $K(x, \mu) \geq 0$ is a kernel function. 
We describe kernel functions in detail in Section 17.1. Here we just give an example, namely the 
Gaussian kernel 

$$K_{\text{gauss}}(x, c) \triangleq \exp\left(-\frac{1}{2\sigma^2}\|c - x\|_2^2\right) \tag{13.102}$$

The parameter $\sigma$ is known as the bandwidth of the kernel. Note that this kernel is shift invariant, 
meaning it is only a function of the distance $r = \|x - c\|_2$, so we can equivalently write this as 

$$K_{\text{gauss}}(r) = \exp\left(-\frac{1}{2\sigma^2} r^2\right) \tag{13.103}$$


This is therefore called a radial basis function kernel or RBF kernel. 
A 1 layer neural net in which we use Equation (13.101) as the hidden layer, with RBF kernels, is 
called an RBF network [BL88]. This has the form 


$$p(y|x; \theta) = p(y|w^{\mathsf{T}}\phi(x)) \tag{13.104}$$

where $\theta = (\mu, w)$. If the centroids $\mu$ are fixed, we can solve for the optimal weights $w$ using 


(regularized) least squares, as discussed in Chapter 11. If the centroids are unknown, we can estimate 
them by using an unsupervised clustering method, such as K-means (Section 21.3). Alternatively, we 
Figure 13.21: (a) xor truth table. (b) Fitting a linear logistic regression classifier using degree 10 polynomial 
expansion. (c) Same model, but using an RBF kernel with centroids specified by the 4 black crosses. Generated 
by logregXorDemo.ipynb. 


can associate one centroid per data point in the training set, to get $\mu_n = x_n$, where now $K = N$. 
This is an example of a non-parametric model, since the number of parameters grows (in this case 
linearly) with the amount of data, and is not independent of N. If K = N, the model can perfectly 
interpolate the data, and hence may overfit. However, by ensuring that the output weight vector 
w is sparse, the model will only use a finite subset of the input examples; this is called a sparse 
kernel machine, and will be discussed in more detail in Section 17.4.1 and Section 17.3. Another 
way to avoid overfitting is to adopt a Bayesian approach, by integrating out the weights w; this gives 
rise to a model called a Gaussian process, which will be discussed in more detail in Section 17.2. 


13.6.1.1 RBF network for regression 


We can use RBF networks for regression by defining $p(y|x, \theta) = \mathcal{N}(w^{\mathsf{T}}\phi(x), \sigma^2)$. For example, 
Figure 13.22 shows a 1d data set fit with K = 10 uniformly spaced RBF prototypes, but with the 
bandwidth ranging from small to large. Small values lead to very wiggly functions, since the predicted 
function value will only be non-zero for points x that are close to one of the prototypes $\mu_k$. If the 
bandwidth is very large, the design matrix reduces to a constant matrix of 1’s, since each point is 
equally close to every prototype; hence the corresponding function is just a straight line. 


13.6.1.2 RBF network for classification 


We can use RBF networks for binary classification by defining $p(y|x, \theta) = \mathrm{Ber}(\sigma(w^{\mathsf{T}}\phi(x)))$. As an 
example, consider the data coming from the exclusive or function. This is a binary-valued function 
of two binary inputs. Its truth table is shown in Figure 13.21(a). In Figure 13.21(b), we have shown 
some data labeled by the xor function, but we have jittered the points to make the picture clearer.⁸ 
We see we cannot separate the data even using a degree 10 polynomial. However, using an RBF 
kernel and just 4 prototypes easily solves the problem as shown in Figure 13.21(c). 
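
To make the role of the RBF features concrete, here is a small NumPy sketch on the four noise-free xor points, with one centroid per point. It is a simplified stand-in for the demo referenced in the figure: instead of logistic regression it fits the output weights by least squares to ±1 targets, and the bandwidth `sigma=0.5` and the helper name `rbf_features` are assumptions made for illustration.

```python
import numpy as np

def rbf_features(X, centroids, sigma):
    # Squared Euclidean distance between every input and every centroid.
    sq_dist = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dist / (2.0 * sigma ** 2))

# The four corners of the xor problem and their labels.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
y = np.array([0, 1, 1, 0])

Phi = rbf_features(X, X, sigma=0.5)   # N x K design matrix with K = N = 4

# Least-squares fit to +/-1 targets (a stand-in for logistic regression).
w, *_ = np.linalg.lstsq(Phi, 2.0 * y - 1.0, rcond=None)
pred = (Phi @ w > 0).astype(int)
print(pred)
```

In the feature space defined by the four kernels, a linear decision rule reproduces the xor labels, which is not possible with a linear rule on the raw inputs.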


8. Jittering is a common visualization trick in statistics, wherein points in a plot/display that would otherwise land on 
top of each other are dispersed with uniform additive noise. 
Figure 13.22: Linear regression using 10 equally spaced RBF basis functions in 1d. Left column: fitted 
function. Middle column: basis functions evaluated on a grid. Right column: design matrix. Top to bottom 
we show different bandwidths for the kernel function: σ = 0.5, 10, 50. Generated by linregRbfDemo.ipynb. 


13.6.2 Mixtures of experts 


When considering regression problems, it is common to assume a unimodal output distribution, such 
as a Gaussian or Student distribution, where the mean and variance are some functions of the input, 
i.e., 

$$p(y|x) = \mathcal{N}(y \mid f_\mu(x), \mathrm{diag}(\sigma_+(f_\sigma(x)))) \tag{13.105}$$


where the f functions may be MLPs (possibly with some shared hidden units, as in Figure 13.5). 
However, this will not work well for one-to-many functions, in which each input can have multiple 
possible outputs. 

Figure 13.23a gives a simple example of such a function. We see that in the middle of the plot there 
are certain x values for which there are two equally probable y values. There are many real world 
problems of this form, e.g., 3d pose prediction of a person from a single image [Bo+08], colorization 
of a black and white image [Gua+17], predicting future frames of a video sequence [VT17], etc. Any 
model which is trained to maximize likelihood using a unimodal output density (even if the model 
is a flexible nonlinear model, such as a neural network) will work poorly on one-to-many functions 
such as these, since it will just produce a blurry average output. 

To prevent this problem of regression to the mean, we can use a conditional mixture model. 
That is, we assume the output is a weighted mixture of K different outputs, corresponding to different 
Figure 13.23: (a) Some data from a one-to-many function. Horizontal axis is the input x, vertical axis is the 
target y = f(x). (b) The responsibilities of each expert for the input domain. (c) Prediction of each expert 
(colored lines) superimposed on the training data. (d) Overall prediction. Mean is red cross, mode is black 
square. Adapted from Figures 5.20 and 5.21 of [Bis06]. Generated by mixexpDemoOneToMany.ipynb. 


modes of the output distribution for each input x. In the Gaussian case, this becomes 

$$p(y|x) = \sum_{k=1}^{K} p(y|x, z = k)\, p(z = k|x) \tag{13.106}$$
$$p(y|x, z = k) = \mathcal{N}(y \mid f_{\mu,k}(x), \mathrm{diag}(f_{\sigma,k}(x))) \tag{13.107}$$
$$p(z = k|x) = \mathrm{Cat}(z \mid \mathrm{softmax}(f_z(x))) \tag{13.108}$$

Here $f_{\mu,k}$ predicts the mean of the k’th Gaussian, $f_{\sigma,k}$ predicts its variance terms, and $f_z$ predicts 
which mixture component to use. This model is called a mixture of experts (MoE) [Jac+91; 
JJ94; YWG12; ME14]. The idea is that the k’th submodel p(y|x, z = k) is considered to be an 
“expert” in a certain region of input space. The function p(z = k|æ) is called a gating function, 
and decides which expert to use, depending on the input values. By picking the most likely expert 
for a given input x, we can “activate” just a subset of the model. This is an example of conditional 
computation, since we decide what expert to run based on the results of earlier computations from 
the gating network [Sha+17]. 

We can train this model using SGD, or using the EM algorithm (see Section 8.7.3 for details on 
the latter method). 
Figure 13.24: Deep MOE with m experts, represented as a neural network. From Figure 1 of [CGG17]. Used 
with kind permission of Jacob Goldberger. 


13.6.2.1 Mixture of linear experts 


In this section, we consider a simple example in which we use linear regression experts and a linear 
classification gating function, i.e., the model has the form: 


$$p(y|x, z = k, \theta) = \mathcal{N}(y \mid w_k^{\mathsf{T}} x, \sigma_k^2) \tag{13.109}$$
$$p(z = k|x, \theta) = \mathrm{Cat}(z \mid \mathrm{softmax}_k(Vx)) \tag{13.110}$$

where $\mathrm{softmax}_k$ is the k’th output from the softmax function. The individual weighting term 
$p(z = k|x)$ is called the responsibility for expert k for input x. In Figure 13.23b, we see how the 
gating network softly partitions the input space amongst the K = 3 experts. 

Each expert p(y|x,z = k) corresponds to a linear regression model with different parameters. 
These are shown in Figure 13.23c. 

If we take a weighted combination of the experts as our output, we get the red curve in Figure 13.23d, 
which is clearly a bad predictor. If instead we only predict using the most active expert (i.e., the 
one with the highest responsibility), we get the discontinuous black curve, which is a much better 
predictor. 
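
The difference between the two prediction rules can be seen in a few lines of NumPy. This is a minimal sketch with hand-picked, hypothetical parameters for K = 2 experts on 1d inputs (the matrices `V` and `W` and the function name `moe_predict` are illustrative, not taken from the demo notebook); each input is augmented with a constant 1 to act as a bias feature.

```python
import numpy as np

def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)   # subtract the max for stability
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def moe_predict(x, V, W):
    # x: (N, D) inputs; V: (K, D) gating weights; W: (K, D) expert weights.
    resp = softmax(x @ V.T)                  # (N, K) responsibilities p(z=k|x)
    expert_means = x @ W.T                   # (N, K) mean of each linear expert
    mean_pred = np.sum(resp * expert_means, axis=1)   # weighted combination
    best = resp.argmax(axis=1)               # most active expert per input
    mode_pred = expert_means[np.arange(len(x)), best]
    return resp, mean_pred, mode_pred

# Hypothetical parameters; feature vector is [x, 1].
V = np.array([[5.0, 0.0], [-5.0, 0.0]])   # expert 0 favoured for x > 0
W = np.array([[1.0, 1.0], [1.0, -1.0]])   # expert means: x + 1 and x - 1
x = np.array([[-1.0, 1.0], [0.0, 1.0], [1.0, 1.0]])

resp, mean_pred, mode_pred = moe_predict(x, V, W)
print(mean_pred, mode_pred)
```

At x = 0 the two experts are equally responsible, so the weighted combination falls halfway between their predictions, a value that neither expert supports, whereas the most active expert always returns a value on one of the two branches.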


13.6.2.2 Mixture density networks 


The gating function and experts can be any kind of conditional probabilistic model, not just a linear 
model. If we make them both DNNs, then the resulting model is called a mixture density network 
(MDN) [Bis94; ZS14] or a deep mixture of experts [CGG17]. See Figure 13.24 for a sketch of 
the model. 
Figure 13.25: A 2-level hierarchical mixture of experts as a neural network. The top gating network chooses 
between the left and right expert, shown by the large boxes; the left and right experts themselves choose between 
their left and right sub-experts. 


13.6.2.3 Hierarchical MOEs 


If each expert is itself an MoE model, the resulting model is called a hierarchical mixture of 
experts [JJ94]. See Figure 13.25 for an illustration of such a model with a two level hierarchy. 

An HME with L levels can be thought of as a “soft” decision tree of depth L, where each example 
is passed through every branch of the tree, and the final prediction is a weighted average. (We discuss 
decision trees in Section 18.1.) 


13.7 Exercises 


Exercise 13.1 [Backpropagation for a MLP] 
(Based on an exercise by Kevin Clark.) 


Consider the following classification MLP with one hidden layer: 


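$$x = \text{input} \in \mathbb{R}^D \tag{13.111}$$
$$z = Wx + b_1 \in \mathbb{R}^K \tag{13.112}$$
$$h = \mathrm{ReLU}(z) \in \mathbb{R}^K \tag{13.113}$$
$$a = Uh + b_2 \in \mathbb{R}^C \tag{13.114}$$
$$\mathcal{L} = \mathrm{CrossEntropy}(y, \mathrm{softmax}(a)) \in \mathbb{R} \tag{13.115}$$

where $x \in \mathbb{R}^D$, $b_1 \in \mathbb{R}^K$, $W \in \mathbb{R}^{K \times D}$, $b_2 \in \mathbb{R}^C$, $U \in \mathbb{R}^{C \times K}$, where D is the size of the input, K is the 
number of hidden units, and C is the number of classes. Show that the gradients for the parameters and 
input are as follows: 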
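$$\nabla_U \mathcal{L} = \left[\frac{\partial \mathcal{L}}{\partial U}\right]_{1,:} = \delta_1 h^{\mathsf{T}} \in \mathbb{R}^{C \times K} \tag{13.116}$$
$$\nabla_{b_2} \mathcal{L} = \left(\frac{\partial \mathcal{L}}{\partial b_2}\right)^{\mathsf{T}} = \delta_1 \in \mathbb{R}^{C} \tag{13.117}$$
$$\nabla_W \mathcal{L} = \left[\frac{\partial \mathcal{L}}{\partial W}\right]_{1,:} = \delta_2 x^{\mathsf{T}} \in \mathbb{R}^{K \times D} \tag{13.118}$$
$$\nabla_{b_1} \mathcal{L} = \left(\frac{\partial \mathcal{L}}{\partial b_1}\right)^{\mathsf{T}} = \delta_2 \in \mathbb{R}^{K} \tag{13.119}$$
$$\nabla_x \mathcal{L} = \left(\frac{\partial \mathcal{L}}{\partial x}\right)^{\mathsf{T}} = W^{\mathsf{T}} \delta_2 \in \mathbb{R}^{D} \tag{13.120}$$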


where the gradients of the loss wrt the two layers (logit and hidden) are given by the following: 


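$$\delta_1 = \nabla_a \mathcal{L} = \left(\frac{\partial \mathcal{L}}{\partial a}\right)^{\mathsf{T}} = (p - y) \in \mathbb{R}^{C} \tag{13.121}$$
$$\delta_2 = \nabla_z \mathcal{L} = \left(\frac{\partial \mathcal{L}}{\partial z}\right)^{\mathsf{T}} = (U^{\mathsf{T}} \delta_1) \odot H(z) \in \mathbb{R}^{K} \tag{13.122}$$

where $p = \mathrm{softmax}(a)$ is the vector of predicted class probabilities and H is the Heaviside function. Note that, in our notation, the gradient (which has the same shape as the 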
variable with respect to which we differentiate) is equal to the Jacobian’s transpose when the variable is a 
vector and to the first slice of the Jacobian when the variable is a matrix. 
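
Before attempting the derivation, it can help to confirm the stated formulas numerically. The following NumPy sketch implements Equations (13.116) to (13.122) and compares them with central finite differences. It is a sanity check rather than a proof; the layer sizes, the random seed, and the helper names `forward`, `backward`, and `finite_difference` are all illustrative assumptions.

```python
import numpy as np

def forward(params, x, y):
    W, b1, U, b2 = params
    z = W @ x + b1
    h = np.maximum(z, 0.0)                 # ReLU
    a = U @ h + b2
    p = np.exp(a - a.max())
    p /= p.sum()                           # softmax
    loss = -np.sum(y * np.log(p))          # cross entropy with one-hot y
    return loss, (z, h, p)

def backward(params, x, y):
    W, b1, U, b2 = params
    _, (z, h, p) = forward(params, x, y)
    delta1 = p - y                         # (13.121)
    delta2 = (U.T @ delta1) * (z > 0)      # (13.122), Heaviside of z
    grads = [np.outer(delta2, x), delta2, np.outer(delta1, h), delta1]
    return grads, W.T @ delta2             # parameter grads, input grad

def finite_difference(params, x, y, idx, step=1e-6):
    # Central-difference estimate of the gradient wrt params[idx].
    est = np.zeros_like(params[idx])
    for pos in np.ndindex(est.shape):
        orig = params[idx][pos]
        params[idx][pos] = orig + step
        up, _ = forward(params, x, y)
        params[idx][pos] = orig - step
        down, _ = forward(params, x, y)
        params[idx][pos] = orig
        est[pos] = (up - down) / (2 * step)
    return est

rng = np.random.default_rng(0)
D, K, C = 4, 5, 3
params = [rng.normal(size=(K, D)), rng.normal(size=K),
          rng.normal(size=(C, K)), rng.normal(size=C)]
x = rng.normal(size=D)
y = np.eye(C)[1]                           # one-hot label for class 1

grads, grad_x = backward(params, x, y)
max_err = max(np.abs(grads[i] - finite_difference(params, x, y, i)).max()
              for i in range(4))
print(max_err)
```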
14.1 Introduction 


In Chapter 13, we discussed multilayer perceptrons (MLPs) as a way to learn functions mapping 
“unstructured” input vectors $x \in \mathbb{R}^D$ to outputs. In this chapter, we extend this to the case where the 
input x has 2d spatial structure. (Similar ideas apply to 1d temporal structure, or 3d spatio-temporal 
structure.) 

To see why it is not a good idea to apply MLPs directly to image data, recall that the core 
operation in an MLP at each hidden layer is computing the activations $z = \varphi(Wx)$, where $x$ is the 
input to a layer, $W$ are the weights, and $\varphi()$ is the nonlinear activation function. Thus the j’th 
element of the hidden layer has value $z_j = \varphi(w_j^{\mathsf{T}} x)$. We can think of this inner product operation 
as comparing the input $x$ to a learned template or pattern $w_j$; if the match is good (large positive 
inner product), the activation of that unit will be large (assuming a ReLU nonlinearity), signalling 
that the j’th pattern is present in the input. 

However, this does not work well if the input is a variable-sized image, $x \in \mathbb{R}^{WHC}$, where W is the 
width, H is the height, and C is the number of input channels (e.g., C = 3 for RGB color). The 
problem is that we would need to learn a different-sized weight matrix W for every size of input 
image. In addition, even if the input was fixed size, the number of parameters needed would be 
prohibitive for reasonably sized images, since the weight matrix would have size (W x H x C) x D, 
where D is the number of outputs (hidden units). The final problem is that a pattern that occurs 
in one location may not be recognized when it occurs in a different location — that is, the model 
may not exhibit translation invariance — because the weights are not shared across locations (see 
Figure 14.1). 

To solve these problems, we will use convolutional neural networks (CNNs), in which we 
replace matrix multiplication with a convolution operation. We explain this in detail in Section 14.2, 
but the basic idea is to divide the input into overlapping 2d image patches, and to compare each 
patch with a set of small weight matrices, or filters, which represent parts of an object; this is 
illustrated in Figure 14.2. We can think of this as a form of template matching. We will learn 
these templates from data, as we explain below. Because the templates are small (often just 3x3 or 
5x5), the number of parameters is significantly reduced. And because we use convolution to do the 
template matching, instead of matrix multiplication, the model will be translationally invariant. This 
is useful for tasks such as image classification, where the goal is to classify if an object is present, 
regardless of its location. 

CNNs have many other applications besides image classification, as we will discuss later in this 
chapter. They can also be applied to 1d inputs (see Section 15.3) and 3d inputs; however, we mostly 
Figure 14.1: Detecting patterns in 2d images using unstructured MLPs does not work well, because the method 
is not translation invariant. We can design a weight vector to act as a matched filter for detecting the 
desired cross-shape. This will give a strong response of 5 if the object is on the left, but a weak response of 1 
if the object is shifted over to the right. Adapted from Figure 7.16 of [SAV20]. 


Figure 14.2: We can classify a digit by looking for certain discriminative features (image templates) occurring 
in the correct (relative) locations. From Figure 5.1 of [Cho17]. Used with kind permission of Francois Chollet. 


focus on the 2d case in this chapter. 


14.2 Common layers 
In this section, we discuss the basics of CNNs. 


14.2.1 Convolutional layers 


We start by describing the basics of convolution in 1d, and then in 2d, and then describe how they 
are used as a key component of CNNs. 


14.2.1.1 Convolution in 1d 


The convolution between two functions, say $f, g : \mathbb{R}^D \to \mathbb{R}$, is defined as 

$$[f \circledast g](z) = \int_{\mathbb{R}^D} f(u)\, g(z - u)\, du \tag{14.1}$$


Now suppose we replace the functions with finite-length vectors, which we can think of as functions 
defined on a finite set of points. For example, suppose f is evaluated at the points {−L, −L + 
Figure 14.3: Discrete convolution of x = [1,2,3,4] with w = [5,6,7] to yield z = [5,16,34,52,45, 28]. We 
see that this operation consists of “flipping” w and then “dragging” it over x, multiplying elementwise, and 
adding up the results. 




Figure 14.4: 1d cross correlation. From Figure 15.8.2 of [Zha+20]. Used with kind permission of Aston 
Zhang. 


1, …, 0, 1, …, L} to yield the weight vector (also called a filter or kernel) $w_{-L} = f(-L)$ up to 
$w_L = f(L)$. Now let g be evaluated at points $\{-N, \ldots, N\}$ to yield the feature vector $x_{-N} = g(-N)$ 
up to $x_N = g(N)$. Then the above equation becomes 

$$[w \circledast x](i) = w_{-L} x_{i+L} + \cdots + w_{-1} x_{i+1} + w_0 x_i + w_1 x_{i-1} + \cdots + w_L x_{i-L} \tag{14.2}$$


(We discuss boundary conditions (edge effects) later on.) We see that we “flip” the weight vector w 
(since indices of w are reversed), and then “drag” it over the x vector, summing up the local windows 
at each point, as illustrated in Figure 14.3. 

There is a very closely related operation, in which we do not flip w first: 


$$[w \ast x](i) = w_{-L} x_{i-L} + \cdots + w_{-1} x_{i-1} + w_0 x_i + w_1 x_{i+1} + \cdots + w_L x_{i+L} \tag{14.3}$$

This is called cross correlation. If the weight vector is symmetric, as is often the case, then cross 
correlation and convolution are the same. In the deep learning literature, the term “convolution” is 
usually used to mean cross correlation; we will follow this convention. 


We can also evaluate the weights w on domain {0,1,..., L — 1} and the features x on domain 
{0,1,...,N — 1}, to eliminate negative indices. Then the above equation becomes 
$$[w \circledast x](i) = \sum_{u=0}^{L-1} w_u x_{i+u} \tag{14.4}$$


See Figure 14.4 for an example. 
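
The distinction between the two operations is easy to check in NumPy. The sketch below reuses the vectors from Figure 14.3; `np.convolve` flips the kernel, while the helper `cross_correlate_valid` (a name introduced here for illustration) implements Equation (14.4) directly, keeping only the windows that fit entirely inside x.

```python
import numpy as np

x = np.array([1, 2, 3, 4])
w = np.array([5, 6, 7])

# True convolution: w is flipped before sliding ("full" keeps partial overlaps).
conv_full = np.convolve(x, w, mode="full")

def cross_correlate_valid(w, x):
    # Equation (14.4): z[i] = sum_u w[u] * x[i + u], no flipping.
    L = len(w)
    return np.array([np.dot(w, x[i:i + L]) for i in range(len(x) - L + 1)])

xcorr = cross_correlate_valid(w, x)

# Cross correlation equals convolution with the flipped kernel.
conv_flipped = np.convolve(x, w[::-1], mode="valid")

print(conv_full)   # [ 5 16 34 52 45 28]
print(xcorr)       # [38 56]
```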


Figure 14.5: Illustration of 2d cross correlation. Generated by conv2d_jax.ipynb. Adapted from Figure 6.2.1 
of [Zha+20]. 




Figure 14.6: Convolving a 2d image (left) with a 3 x 3 filter (middle) produces a 2d response map (right). 
The bright spots of the response map correspond to locations in the image which contain diagonal lines sloping 
down and to the right. From Figure 5.3 of [Cho17]. Used with kind permission of Francois Chollet. 


14.2.1.2 Convolution in 2d 
In 2d, Equation (14.4) becomes 


$$[W \circledast X](i, j) = \sum_{u=0}^{H-1} \sum_{v=0}^{W-1} w_{u,v}\, x_{i+u, j+v} \tag{14.5}$$


where the 2d filter W has size H x W. For example, consider convolving a 3 x 3 input X with a 
2 x 2 kernel W to compute a 2 x 2 output Y: 


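$$Y = \begin{pmatrix} w_1 & w_2 \\ w_3 & w_4 \end{pmatrix} \circledast \begin{pmatrix} x_1 & x_2 & x_3 \\ x_4 & x_5 & x_6 \\ x_7 & x_8 & x_9 \end{pmatrix} \tag{14.6}$$

$$= \begin{pmatrix} (w_1 x_1 + w_2 x_2 + w_3 x_4 + w_4 x_5) & (w_1 x_2 + w_2 x_3 + w_3 x_5 + w_4 x_6) \\ (w_1 x_4 + w_2 x_5 + w_3 x_7 + w_4 x_8) & (w_1 x_5 + w_2 x_6 + w_3 x_8 + w_4 x_9) \end{pmatrix} \tag{14.7}$$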
See Figure 14.5 for a visualization of this process. 

We can think of 2d convolution as template matching, since the output at a point (i, j) will 
be large if the corresponding image patch centered on (i, j) is similar to W. If the template W 
corresponds to an oriented edge, then convolving with it will cause the output heat map to “light up” 
in regions that contain edges that match that orientation, as shown in Figure 14.6. More generally, 
we can think of convolution as a form of feature detection. The resulting output $Y = W \circledast X$ is 
therefore called a feature map. 


14.2.1.3 Convolution as matrix-vector multiplication 


Since convolution is a linear operator, we can represent it by matrix multiplication. For example, 
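consider Equation (14.7). We can rewrite this as matrix-vector multiplication by flattening the 2d 
matrix X into a 1d vector x, and multiplying by a Toeplitz-like matrix C derived from the kernel 
W, as follows: 

$$y = Cx = \begin{pmatrix} w_1 & w_2 & 0 & w_3 & w_4 & 0 & 0 & 0 & 0 \\ 0 & w_1 & w_2 & 0 & w_3 & w_4 & 0 & 0 & 0 \\ 0 & 0 & 0 & w_1 & w_2 & 0 & w_3 & w_4 & 0 \\ 0 & 0 & 0 & 0 & w_1 & w_2 & 0 & w_3 & w_4 \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ x_5 \\ x_6 \\ x_7 \\ x_8 \\ x_9 \end{pmatrix} \tag{14.8}$$

$$= \begin{pmatrix} w_1 x_1 + w_2 x_2 + w_3 x_4 + w_4 x_5 \\ w_1 x_2 + w_2 x_3 + w_3 x_5 + w_4 x_6 \\ w_1 x_4 + w_2 x_5 + w_3 x_7 + w_4 x_8 \\ w_1 x_5 + w_2 x_6 + w_3 x_8 + w_4 x_9 \end{pmatrix} \tag{14.9}$$

We can recover the 2 x 2 output by reshaping the 4 x 1 vector y back to Y.¹ 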
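
The equivalence between the sliding-window form and the matrix form can be checked directly. The following NumPy sketch is a plain-loop illustration, separate from the referenced notebook; the function names `cross_correlate_2d` and `conv_matrix` and the numeric kernel are assumptions chosen for the example.

```python
import numpy as np

def cross_correlate_2d(W, X):
    # Equation (14.5) in "valid" mode: slide W over X without flipping.
    fh, fw = W.shape
    out_h, out_w = X.shape[0] - fh + 1, X.shape[1] - fw + 1
    Y = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            Y[i, j] = np.sum(W * X[i:i + fh, j:j + fw])
    return Y

def conv_matrix(W, in_shape):
    # Build C so that C @ X.ravel() equals cross_correlate_2d(W, X).ravel().
    fh, fw = W.shape
    xh, xw = in_shape
    out_h, out_w = xh - fh + 1, xw - fw + 1
    C = np.zeros((out_h * out_w, xh * xw))
    for i in range(out_h):
        for j in range(out_w):
            for u in range(fh):
                for v in range(fw):
                    C[i * out_w + j, (i + u) * xw + (j + v)] = W[u, v]
    return C

W = np.array([[1.0, 2.0], [3.0, 4.0]])       # w1..w4
X = np.arange(1.0, 10.0).reshape(3, 3)       # x1..x9 = 1..9
Y = cross_correlate_2d(W, X)
C = conv_matrix(W, X.shape)
y = C @ X.ravel()

print(C[0])              # first row: [1. 2. 0. 3. 4. 0. 0. 0. 0.]
print(y.reshape(2, 2))   # matches Y
```

Each row of C contains the same four kernel weights, shifted to a different position, which is the weight tying described in the text.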

Thus we see that CNNs are like MLPs where the weight matrices have a special sparse structure, 
and the elements are tied across spatial locations. This implements the idea of translation invariance, 
and massively reduces the number of parameters compared to a weight matrix in a standard fully 
connected or dense layer, as used in MLPs. 


14.2.1.4 Boundary conditions and padding 


In Equation (14.7), we saw that convolving a 3 x 3 image with a 2 x 2 filter resulted in a 2 x 2 
output. In general, convolving a $f_h \times f_w$ filter over an image of size $x_h \times x_w$ produces an output of 
size $(x_h - f_h + 1) \times (x_w - f_w + 1)$; this is called valid convolution, since we only apply the filter to 
“valid” parts of the input, i.e., we don’t let it “slide off the ends”. If we want the output to have the 
same size as the input, we can use zero-padding, which means we add a border of 0s to the image, 
as illustrated in Figure 14.7. This is called same convolution. 


1. See conv2d_jax.ipynb for a demo. 
Figure 14.7: Same-convolution (using zero-padding) ensures the output is the same size as the input. Adapted 
from Figure 8.3 of [SAV20]. 


Figure 14.8: Illustration of padding and strides in 2d convolution. (a) We apply “same convolution” to a 
5 x 7 input (with zero padding) using a 3 x 3 filter to create a 5 x 7 output. (b) Now we use a stride of 2, so 
the output has size 3 x 4. Adapted from Figures 14.3-14.4 of [Gér19]. 


In general, if the input has size $x_h \times x_w$, we use a kernel of size $f_h \times f_w$, and we use zero padding on 
each side of size $p_h$ and $p_w$, then the output has the following size [DV16]: 

$$(x_h + 2p_h - f_h + 1) \times (x_w + 2p_w - f_w + 1) \tag{14.10}$$

For example, consider Figure 14.8a. We have $p = 1$, $f = 3$, $x_h = 5$ and $x_w = 7$, so the output has 
size 

$$(5 + 2 - 3 + 1) \times (7 + 2 - 3 + 1) = 5 \times 7 \tag{14.11}$$

If we set $2p = f - 1$, then the output will have the same size as the input. 


Figure 14.9: Illustration of 2d convolution applied to an input with 2 channels. Generated by conv2d_jax.ipynb. 
Adapted from Figure 6.4.1 of [Zha+20]. 


14.2.1.5 Strided convolution 


Since each output pixel is generated by a weighted combination of inputs in its receptive field 
(based on the size of the filter), neighboring outputs will be very similar in value, since their inputs 
are overlapping. We can reduce this redundancy (and speed up computation) by skipping every s’th 
input. This is called strided convolution. This is illustrated in Figure 14.8b, where we convolve a 
5 x 7 image with a 3 x 3 filter with stride 2 to get a 3 x 4 output. 

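In general, if the input has size $x_h \times x_w$, we use a kernel of size $f_h \times f_w$, we use zero padding on 
each side of size $p_h$ and $p_w$, and we use strides of size $s_h$ and $s_w$, then the output has the following 
size [DV16]: 

$$\left\lfloor \frac{x_h + 2p_h - f_h + s_h}{s_h} \right\rfloor \times \left\lfloor \frac{x_w + 2p_w - f_w + s_w}{s_w} \right\rfloor \tag{14.12}$$

For example, consider Figure 14.8b, where we set the stride to s = 2. Now the output is smaller than 
the input, and has size 

$$\left\lfloor \frac{5 + 2 - 3 + 2}{2} \right\rfloor \times \left\lfloor \frac{7 + 2 - 3 + 2}{2} \right\rfloor = \left\lfloor \frac{6}{2} \right\rfloor \times \left\lfloor \frac{8}{2} \right\rfloor = 3 \times 4 \tag{14.13}$$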
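
The output-size formula is simple enough to wrap in a helper for quick checks. This is a small sketch in plain Python; the function name `conv_output_size` is introduced here for illustration, and it handles one spatial dimension at a time.

```python
def conv_output_size(x, f, p=0, s=1):
    # floor((x + 2p - f + s) / s), Equation (14.12), for one spatial dimension.
    return (x + 2 * p - f + s) // s

# Valid convolution of a 3 x 3 input with a 2 x 2 filter (Equation 14.7).
valid = (conv_output_size(3, 2), conv_output_size(3, 2))

# Same convolution from Figure 14.8a: 5 x 7 input, 3 x 3 filter, padding 1.
same = (conv_output_size(5, 3, p=1), conv_output_size(7, 3, p=1))

# Strided convolution from Figure 14.8b: same setup with stride 2.
strided = (conv_output_size(5, 3, p=1, s=2), conv_output_size(7, 3, p=1, s=2))

print(valid, same, strided)   # (2, 2) (5, 7) (3, 4)
```

With s = 1 the formula reduces to Equation (14.10), since (x + 2p - f + 1) is already an integer.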


14.2.1.6 Multiple input and output channels 


In Figure 14.6, the input was a gray-scale image. In general, the input will have multiple channels 
(e.g., RGB, or hyper-spectral bands for satellite images). We can extend the definition of convolution 
to this case by defining a kernel for each input channel; thus now W is a 3d weight matrix or tensor. 
We compute the output by convolving channel c of the input with kernel $W_{:,:,c}$, and then summing 
over channels: 

$$z_{i,j} = b + \sum_{u=0}^{H-1} \sum_{v=0}^{W-1} \sum_{c=0}^{C-1} x_{si+u,\, sj+v,\, c}\, w_{u,v,c} \tag{14.14}$$
Figure 14.10: Illustration of a CNN with 2 convolutional layers. The input has 3 color channels. The feature 
maps at internal layers have multiple channels. The cylinders correspond to hypercolumns, which are feature 
vectors at a certain location. Adapted from Figure 14.6 of [Gér19]. 


where s is the stride (which we assume is the same for both height and width, for simplicity), and b 
is the bias term. This is illustrated in Figure 14.9. 

Each weight matrix can detect a single kind of feature. We typically want to detect multiple kinds 
of features, as illustrated in Figure 14.2. We can do this by making W into a 4d weight matrix. The 
filter to detect feature type d in input channel c is stored in $W_{:,:,c,d}$. We extend the definition of 
convolution to this case as follows: 

$$z_{i,j,d} = b_d + \sum_{u=0}^{H-1} \sum_{v=0}^{W-1} \sum_{c=0}^{C-1} x_{si+u,\, sj+v,\, c}\, w_{u,v,c,d} \tag{14.15}$$

This is illustrated in Figure 14.10. Each vertical cylindrical column denotes the set of output features 
at a given location, $z_{i,j,1:D}$; this is sometimes called a hypercolumn. Each element is a different 
weighted combination of the C features in the receptive field of each of the feature maps in the layer 
below.² 
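
Equation (14.15) translates almost line by line into code. The sketch below is a deliberately unoptimized NumPy version without padding; the function name `conv2d`, the array sizes, and the random inputs are assumptions for illustration. It uses the (H, W, C, D) filter layout mentioned in the footnote.

```python
import numpy as np

def conv2d(X, W, b, stride=1):
    # X: (height, width, C) input; W: (H, W, C, D) filters; b: (D,) biases.
    fh, fw, C, D = W.shape
    out_h = (X.shape[0] - fh) // stride + 1
    out_w = (X.shape[1] - fw) // stride + 1
    Z = np.zeros((out_h, out_w, D))
    for i in range(out_h):
        for j in range(out_w):
            patch = X[stride * i:stride * i + fh, stride * j:stride * j + fw, :]
            # Sum over u, v, c for every output channel d at once.
            Z[i, j, :] = b + np.tensordot(patch, W, axes=([0, 1, 2], [0, 1, 2]))
    return Z

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 7, 3))       # 5 x 7 image with C = 3 channels
W = rng.normal(size=(3, 3, 3, 2))    # 3 x 3 filters, C = 3 inputs, D = 2 outputs
b = rng.normal(size=2)
Z = conv2d(X, W, b, stride=2)
print(Z.shape)                       # (2, 3, 2)
```

Each `Z[i, j, :]` is one hypercolumn: D different weighted combinations of the same input patch.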


14.2.1.7 1 x 1 (pointwise) convolution 


Sometimes we just want to take a weighted combination of the features at a given location, rather 
than across locations. This can be done using 1x1 convolution, also called pointwise convolution. 


2. In Tensorflow, a filter for 2d CNNs has shape (H, W, C, D), and a minibatch of feature maps has shape (batch-size, 
image-height, image-width, image-channels); this is called NHWC format. Other systems use different data layouts. 


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 






Figure 14.11: Mapping 3 channels to 2 using convolution with a filter of size 1 x 1 x 3 x 2. Adapted from 
Figure 6.4.2 of [Zha+20]. 




Figure 14.12: Illustration of maxpooling with a 2x2 filter and a stride of 1. Adapted from Figure 6.5.1 of 
[Zha+20]. 


This changes the number of channels from C to D, without changing the spatial dimensionality: 


$$z_{i,j,d} = b_d + \sum_{c=0}^{C-1} x_{i,j,c}\, w_{0,0,c,d} \tag{14.16}$$


This can be thought of as a single layer MLP applied to each feature column in parallel. 
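
The "MLP applied to each feature column" view can be written as a single matrix product. In this NumPy sketch (array sizes and random values are illustrative assumptions), the 1 x 1 filter is just a C x D matrix that is applied independently at every spatial location.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 5, 3))       # H x W x C feature map
W = rng.normal(size=(1, 1, 3, 2))    # 1 x 1 x C x D filter
b = rng.normal(size=2)

# Equation (14.16): the same C -> D linear map at every location (i, j).
Z = X @ W[0, 0] + b
print(Z.shape)                       # (4, 5, 2)
```

The spatial dimensions are unchanged, and only the number of channels goes from C = 3 to D = 2.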


14.2.2 Pooling layers 


Convolution will preserve information about the location of input features (modulo reduced resolution), 
a property known as equivariance. In some cases we want to be invariant to the location. For 
example, when performing image classification, we may just want to know if an object of interest 
(e.g., a face) is present anywhere in the image. 

One simple way to achieve this is called max pooling, which just computes the maximum over 
its incoming values, as illustrated in Figure 14.12. An alternative is to use average pooling, which 
replaces the max by the mean. In either case, the output neuron has the same response no matter 
Figure 14.13: A simple CNN for classifying images. Adapted from https://blog.floydhub.com/building-your-first-convnet/. 


where the input pattern occurs within its receptive field. (Note that we apply pooling to each feature 
channel independently.) 

If we average over all the locations in a feature map, the method is called global average pooling. 
Thus we can convert an H x W x D feature map into a 1 x 1 x D dimensional feature map; this can 
be reshaped to a D-dimensional vector, which can be passed into a fully connected layer to map it 
to a C-dimensional vector before passing into a softmax output. The use of global average pooling 
means we can apply the classifier to an image of any size, since the final feature map will always be 
converted to a fixed D-dimensional vector before being mapped to a distribution over the C classes. 
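
Both kinds of pooling are short to write down. The following NumPy sketch is illustrative only (the function names `max_pool2d` and `global_average_pool` and the toy inputs are assumptions); it applies pooling to each channel independently and shows that global average pooling returns a vector of the same length for inputs of different spatial sizes.

```python
import numpy as np

def max_pool2d(X, size=2, stride=2):
    # X: (H, W, D); the max is taken over each spatial window, per channel.
    out_h = (X.shape[0] - size) // stride + 1
    out_w = (X.shape[1] - size) // stride + 1
    Z = np.zeros((out_h, out_w, X.shape[2]))
    for i in range(out_h):
        for j in range(out_w):
            window = X[stride * i:stride * i + size,
                       stride * j:stride * j + size, :]
            Z[i, j, :] = window.max(axis=(0, 1))
    return Z

def global_average_pool(X):
    # Average over all spatial locations, leaving one value per channel.
    return X.mean(axis=(0, 1))

X = np.arange(9.0).reshape(3, 3, 1)          # single-channel 3 x 3 input
pooled = max_pool2d(X, size=2, stride=1)
print(pooled[:, :, 0])                       # [[4. 5.] [7. 8.]]

small = global_average_pool(np.ones((4, 4, 3)))
large = global_average_pool(np.ones((8, 6, 3)))
print(small.shape, large.shape)              # (3,) (3,)
```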


14.2.3 Putting it all together 


A common design pattern is to create a CNN by alternating convolutional layers with max pooling 
layers, followed by a final linear classification layer at the end. This is illustrated in Figure 14.13. 
(We omit normalization layers in this example, since the model is quite shallow.) This design pattern 
first appeared in Fukushima’s neocognitron [Fuk75], and was inspired by Hubel and Wiesel’s model 
of simple and complex cells in the mammalian visual cortex [HW62]. In 1998 Yann LeCun used a similar 
design in his eponymous LeNet model [LeC+98], which used backpropagation and SGD to estimate 
the parameters. This design pattern continues to be popular in neurally-inspired models of visual 
object recognition [RP99], as well as various practical applications (see Section 14.3 and Section 14.5). 


14.2.4 Normalization layers 


The basic design in Figure 14.13 works well for shallow CNNs, but it can be difficult to scale it to 
deeper models, due to problems with vanishing or exploding gradients, as explained in Section 13.4.2. 
A common solution to this problem is to add extra layers to the model, to standardize the statistics 
of the hidden units (i.e., to ensure they are zero mean and unit variance), just like we do to the 
inputs of many models. We discuss various kinds of normalization layers below.

14.2.4.1 Batch normalization


The most popular normalization layer is called batch normalization (BN) [IS15]. This ensures 
the distribution of the activations within a layer has zero mean and unit variance, when averaged 
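across the samples in a minibatch. More precisely, we replace the activation vector $z_n$ (or sometimes
the pre-activation vector $a_n$) for example $n$ (in some layer) with $\tilde{z}_n$, which is computed as follows:


$$\tilde{z}_n = \gamma \odot \hat{z}_n + \beta \tag{14.17}$$
$$\hat{z}_n = \frac{z_n - \mu_{\mathcal{B}}}{\sqrt{\sigma^2_{\mathcal{B}} + \epsilon}} \tag{14.18}$$
$$\mu_{\mathcal{B}} = \frac{1}{|\mathcal{B}|} \sum_{z \in \mathcal{B}} z \tag{14.19}$$
$$\sigma^2_{\mathcal{B}} = \frac{1}{|\mathcal{B}|} \sum_{z \in \mathcal{B}} (z - \mu_{\mathcal{B}})^2 \tag{14.20}$$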


where $\mathcal{B}$ is the minibatch containing example $n$, $\mu_{\mathcal{B}}$ is the mean of the activations for this batch³, $\sigma^2_{\mathcal{B}}$
is the corresponding variance, $\hat{z}_n$ is the standardized activation vector, $\tilde{z}_n$ is the shifted and scaled
version (the output of the BN layer), $\beta$ and $\gamma$ are learnable parameters for this layer, and $\epsilon > 0$ is a
small constant. Since this transformation is differentiable, we can easily pass gradients back to the
input of the layer and to the BN parameters $\beta$ and $\gamma$.
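The following is a minimal NumPy sketch of Equations (14.17) to (14.20) for a fully connected layer, where statistics are computed per feature across the minibatch. The function name `batch_norm_train` and the test data are illustrative assumptions; the book's own implementation is in batchnorm_jax.ipynb.

```python
import numpy as np

def batch_norm_train(Z, gamma, beta, eps=1e-5):
    """Apply Equations (14.17)-(14.20) to a minibatch Z of shape (batch, features)."""
    mu = Z.mean(axis=0)                      # (14.19) per-feature batch mean
    var = Z.var(axis=0)                      # (14.20) per-feature batch variance (divides by |B|)
    Z_hat = (Z - mu) / np.sqrt(var + eps)    # (14.18) standardize
    return gamma * Z_hat + beta, mu, var     # (14.17) scale and shift

rng = np.random.default_rng(0)
Z = rng.normal(loc=3.0, scale=2.0, size=(64, 5))
gamma, beta = np.ones(5), np.zeros(5)
Z_tilde, mu, var = batch_norm_train(Z, gamma, beta)
```

With `gamma` set to ones and `beta` to zeros, each column of `Z_tilde` has mean close to 0 and variance close to 1, whatever the scale of the input.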

When applied to the input layer, batch normalization is equivalent to the usual standardization 
procedure we discussed in Section 10.2.8. Note that the mean and variance for the input layer 
can be computed once, since the data is static. However, the empirical means and variances of 
the internal layers keep changing, as the parameters adapt. (This is sometimes called “internal 
covariate shift”.) This is why we need to recompute $\mu$ and $\sigma^2$ on each minibatch.

At test time, we may have a single input, so we cannot compute batch statistics. The standard 
solution to this is as follows: after training, compute $\mu_l$ and $\sigma_l^2$ for layer $l$ across all the examples
in the training set (i.e. using the full batch), and then “freeze” these parameters, and add them
to the list of other parameters for the layer, namely $\beta_l$ and $\gamma_l$. At test time, we then use these
frozen training values for $\mu_l$ and $\sigma_l^2$, rather than computing statistics from the test batch. Thus
when using a model with BN, we need to specify if we are using it for inference or training. (See 
batchnorm_jax.ipynb for some sample code.) 

For speed, we can combine a frozen batch norm layer with the previous layer. In particular suppose 
the previous layer computes $XW + b$; combining this with BN gives $\gamma \odot (XW + b - \mu)/\sigma + \beta$. If
we define $W' = \gamma \odot W/\sigma$ and $b' = \gamma \odot (b - \mu)/\sigma + \beta$, then we can write the combined layers as
$XW' + b'$. This is called fused batchnorm. Similar tricks can be developed to speed up BN during
training [Jun+19]. 
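The fusion identity can be checked numerically. The sketch below uses random stand-ins for the frozen statistics and the learned BN parameters (all values are assumptions for illustration), and treats $\sigma$ as already including the small constant $\epsilon$.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D_in, D_out = 4, 3, 5
X = rng.normal(size=(N, D_in))
W = rng.normal(size=(D_in, D_out))
b = rng.normal(size=D_out)

# Frozen BN statistics and parameters (random stand-ins for values saved after training).
mu = rng.normal(size=D_out)
sigma = rng.uniform(0.5, 2.0, size=D_out)   # here sigma already includes the epsilon term
gamma = rng.normal(size=D_out)
beta = rng.normal(size=D_out)

# Unfused: linear layer followed by frozen batch norm.
unfused = gamma * (X @ W + b - mu) / sigma + beta

# Fused: fold the BN scale and shift into the weights and bias.
W_fused = W * (gamma / sigma)               # scales column j of W by gamma_j / sigma_j
b_fused = gamma * (b - mu) / sigma + beta
fused = X @ W_fused + b_fused
```

The two outputs agree to floating point precision, but the fused version needs only one matrix multiply and one addition per example at inference time.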

The benefits of batch normalization (in terms of training speed and stability) can be quite dramatic, 
especially for deep CNNs. The exact reasons for this are still unclear, but BN seems to make the 
optimization landscape significantly smoother [San+18b]. It also reduces the sensitivity to the 
learning rate [ALL18]. In addition to computational advantages, it has statistical advantages. In 


3. When applied to a convolutional layer, we average across spatial locations and across examples, but not across
channels (so the length of $\mu$ is the number of channels). When applied to a fully connected layer, we just average
across examples (so the length of $\mu$ is the width of the layer).


Figure 14.14: Illustration of different activation normalization methods for a CNN. Each subplot shows a 
feature map tensor, with N as the batch axis, C as the channel axis, and (H, W) as the spatial axes. The 
pixels in blue are normalized by the same mean and variance, computed by aggregating the values of these 
pixels. Left to right: batch norm, layer norm, instance norm, and group norm (with 2 groups of 3 channels). 
From Figure 2 of [WH18]. Used with kind permission of Kaiming He. 


particular, BN acts like a regularizer; indeed it can be shown to be equivalent to a form of approximate 
Bayesian inference [TAS18; Luo+19]. 

However, the reliance on a minibatch of data causes several problems. In particular, it can result 
in unstable estimates of the parameters when training with small batch sizes, although a more 
recent version of the method, known as batch renormalization [Iof17], partially addresses this.
We discuss some other alternatives to batch norm below. 


14.2.4.2 Other kinds of normalization layer 


In Section 14.2.4.1 we discussed batch normalization, which standardizes all the activations within 
a given feature channel to be zero mean and unit variance. This can significantly help with training, 
and allow for a larger learning rate. (See batchnorm_jax.ipynb for some sample code.) 

Although batch normalization works well, it struggles when the batch size is small, since the 
estimated mean and variance parameters can be unreliable. One solution is to compute the mean 
and variance by pooling statistics across other dimensions of the tensor, but not across examples 
in the batch. More precisely, let $z_i$ refer to the $i$'th element of a tensor; in the case of 2d images,
the index $i$ has 4 components, indicating batch, height, width and channel, $i = (i_N, i_H, i_W, i_C)$. We
compute the mean and standard deviation for each index $z_i$ as follows:


$$\mu_i = \frac{1}{|S_i|} \sum_{k \in S_i} z_k, \qquad \sigma_i = \sqrt{\frac{1}{|S_i|} \sum_{k \in S_i} (z_k - \mu_i)^2 + \epsilon} \tag{14.21}$$


where $S_i$ is the set of elements we average over. We then compute $\hat{z}_i = (z_i - \mu_i)/\sigma_i$ and $\tilde{z}_i = \gamma_c \hat{z}_i + \beta_c$,
where $c$ is the channel corresponding to index $i$.

In batch norm, we pool over batch, height, width, so $S_i$ is the set of all locations in the tensor
that match the channel index of i. To avoid problems with small batches, we can instead pool over 
channel, height and width, but match on the batch index. This is known as layer normalization 
[BKH16]. (See layer_norm_jax.ipynb for some sample code.) Alternatively, we can have separate 
normalization parameters for each example in the batch and for each channel. This is known as 
instance normalization [UVL16].

A natural generalization of the above methods is known as group normalization [WH18], where
we pool over all locations whose channel is in the same group as $i$'s. This is illustrated in Figure 14.14.
Layer normalization is a special case in which there is a single group, containing all the channels. 
Instance normalization is a special case in which there are C groups, one per channel. In [WH18], 
they show experimentally that it can be better (in terms of training speed, as well as training and test 
accuracies) to use groups that are larger than individual channels, but smaller than all the channels. 
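The four methods differ only in which axes are pooled when computing the statistics of Equation (14.21). The sketch below makes this explicit for a tensor in (N, H, W, C) layout; the function `normalize` is a hypothetical helper written for this illustration, it omits the learnable scale and shift, and it assumes the number of channels is divisible by `num_groups`.

```python
import numpy as np

def normalize(z, kind, num_groups=1, eps=1e-5):
    """Standardize a tensor z of shape (N, H, W, C) by pooling over the set S_i for each method."""
    N, H, W, C = z.shape
    if kind == "batch":        # pool over batch, height, width; match on channel
        axes = (0, 1, 2)
    elif kind == "layer":      # pool over height, width, channel; match on batch index
        axes = (1, 2, 3)
    elif kind == "instance":   # pool over height, width; match on batch index and channel
        axes = (1, 2)
    elif kind == "group":      # pool over height, width, and the channels within a group
        g = z.reshape(N, H, W, num_groups, C // num_groups)
        mu = g.mean(axis=(1, 2, 4), keepdims=True)
        sigma = np.sqrt(g.var(axis=(1, 2, 4), keepdims=True) + eps)
        return ((g - mu) / sigma).reshape(N, H, W, C)
    else:
        raise ValueError(f"unknown kind: {kind}")
    mu = z.mean(axis=axes, keepdims=True)
    sigma = np.sqrt(z.var(axis=axes, keepdims=True) + eps)
    return (z - mu) / sigma

rng = np.random.default_rng(0)
z = rng.normal(size=(2, 4, 4, 6))
one_group = normalize(z, "group", num_groups=1)     # should match layer norm
per_channel = normalize(z, "group", num_groups=6)   # should match instance norm
```

Setting `num_groups=1` reproduces layer normalization and `num_groups=C` reproduces instance normalization, which is the special-case relationship described above.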

More recently, [SK20] proposed filter response normalization which is an alternative to batch 
norm that works well even with a minibatch size of 1. The idea is to define each group as all locations 
with a single channel and batch sample (as in instance normalization), but then to just divide by 
the mean squared norm instead of standardizing. That is, if the input (for a given channel and 
batch entry) is $z = z_{b,:,:,c} \in \mathbb{R}^N$, we compute $\hat{z} = z/\sqrt{\nu^2 + \epsilon}$, where $\nu^2 = \sum_{ij} z_{bijc}^2/N$, and then
$\tilde{z} = \gamma_c \hat{z} + \beta_c$. Since there is no mean centering, the activations can drift away from 0, which can
have detrimental effects, especially with ReLU activations. To compensate for this, the authors
propose to add a thresholded linear unit (TLU) at the output. This has the form $y = \max(x, \tau)$, where
$\tau$ is a learnable offset. The combination of FRN and TLU results in good performance on image
classification and object detection even with a batch size of 1. 
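A minimal NumPy sketch of FRN followed by TLU is shown below, assuming an input in (B, H, W, C) layout. The function name `frn_tlu` and the parameter values are illustrative assumptions rather than the reference implementation of [SK20].

```python
import numpy as np

def frn_tlu(z, gamma, beta, tau, eps=1e-6):
    """Filter response normalization followed by a thresholded linear unit.

    z has shape (B, H, W, C); gamma, beta and tau have shape (C,).
    """
    nu2 = np.mean(z ** 2, axis=(1, 2), keepdims=True)   # mean squared norm per (example, channel)
    z_hat = z / np.sqrt(nu2 + eps)                      # no mean centering
    z_tilde = gamma * z_hat + beta
    return np.maximum(z_tilde, tau)                     # TLU: y = max(x, tau)

rng = np.random.default_rng(0)
z = rng.normal(size=(1, 8, 8, 3))                       # batch size of 1
C = z.shape[-1]
y = frn_tlu(z, gamma=np.ones(C), beta=np.zeros(C), tau=np.full(C, -0.5))
```

Because the statistics are computed separately for each example and channel, the result does not depend on what else is in the batch.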


14.2.4.3 Normalizer-free networks 


Recently, [Bro+21] have proposed a method called normalizer-free networks, which is a way to 
train deep residual networks without using batchnorm or any other form of normalization layer. 
The key is to replace it with adaptive gradient clipping, as an alternative way to avoid training 
instabilities. That is, we use Equation (13.70), but adapt the clipping strength dynamically. The 
resulting model is faster to train, and more accurate, than other competitive models trained with 
batchnorm. 


14.3 Common architectures for image classification 


It is common to use CNNs to perform image classification, which is the task of estimating the function 
$f : \mathbb{R}^{H \times W \times K} \to \{0,1\}^C$, where K is the number of input channels (e.g., K = 3 for RGB images),
and C is the number of class labels. 

In this section, we briefly review various CNNs that have been developed over the years to 
solve image classification tasks. See e.g., [Kha+20] for a more extensive review of CNNs, and e.g., 
https://github.com/rwightman/pytorch-image-models for an up-to-date repository of code and 
models (in PyTorch). 


14.3.1 LeNet 


One of the earliest CNNs, created in 1998, is known as LeNet [LeC+98], named after its creator, 
Yann LeCun. It was designed to classify images of handwritten digits, and was trained on the MNIST 
dataset introduced in Section 3.5.2. The model is shown in Figure 14.15. (See also Figure 14.16a for a 
more compact representation of the model.) Some predictions of this model are shown in Figure 14.17. 
After just 1 epoch, the test accuracy is already 98.8%. By contrast, the MLP in Section 13.2.4.2 had 
an accuracy of 95.9% after 1 epoch. More rounds of training can further increase accuracy to a point 
where performance is indistinguishable from label noise. (See lenet_jax.ipynb for some sample code.)

Figure 14.15: LeNet5, a convolutional neural net for classifying handwritten digits. From Figure 6.6.1 of
[Zha+20]. Used with kind permission of Aston Zhang.

Figure 14.16: (a) LeNet5. We assume the input has size 1 x 28 x 28, as is the case for MNIST. From Figure
6.6.2 of [Zha+20]. Used with kind permission of Aston Zhang. (b) AlexNet. We assume the input has size
3 x 224 x 224, as is the case for (cropped and rescaled) images from ImageNet. From Figure 7.1.2 of [Zha+20].
Used with kind permission of Aston Zhang.

Figure 14.17: Results of applying a CNN to some MNIST images (cherry picked to include some errors). Red is
incorrect, blue is correct. (a) After 1 epoch of training. (b) After 2 epochs. Generated by cnn_mnist_tf.ipynb.


Of course, classifying isolated digits is of limited applicability: in the real world, people usually 
write strings of digits or other letters. This requires both segmentation and classification. LeCun 
and colleagues devised a way to combine convolutional neural networks with a model similar to a 
conditional random field to solve this problem. The system was deployed by the US postal service. 
See [LeC+98] for a more detailed account of the system. 


14.3.2 AlexNet 


Although CNNs have been around for many years, it was not until the paper of [KSH12] in 2012 
that mainstream computer vision researchers paid attention to them. In that paper, the authors 
showed how to reduce the (top 5) error rate on the ImageNet challenge (Section 1.5.1.2) from the 
previous best of 26% to 15%, which was a dramatic improvement. This model became known as
AlexNet, named after its creator, Alex Krizhevsky.

Figure 14.16b shows the architecture. It is very similar to LeNet, shown in Figure 14.16a, with
the following differences: it is deeper (8 layers of adjustable parameters (i.e., excluding the pooling 
layers) instead of 5); it uses ReLU nonlinearities instead of tanh (see Section 13.2.3 for why this is 
important); it uses dropout (Section 13.5.4) for regularization instead of weight decay; and it stacks 
several convolutional layers on top of each other, rather than strictly alternating between convolution 
and pooling. Stacking multiple convolutional layers together has the advantage that the receptive 
fields become larger as the output of one layer is fed into another (for example, three 3 x 3 filters in 
a row will have a receptive field size of 7 x 7). This is better than using a single layer with a larger 
receptive field, since the multiple layers also have nonlinearities in between. Also, three 3 x 3 filters 
have fewer parameters than one 7 x 7. 
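The receptive field and parameter claims can be checked with a few lines of arithmetic. The sketch below assumes stride-1 convolutions with the same number of input and output channels; the channel count `C = 64` is an arbitrary illustrative value, not one taken from AlexNet.

```python
def receptive_field(kernel_sizes):
    """Receptive field (along one axis) of a stack of stride-1 convolutions."""
    rf = 1
    for k in kernel_sizes:
        rf += k - 1
    return rf

def num_weights(kernel_size, channels):
    """Number of weights (ignoring biases) in a conv layer with `channels` inputs and outputs."""
    return kernel_size * kernel_size * channels * channels

C = 64  # illustrative channel count (an assumption, not taken from AlexNet)
stacked_rf = receptive_field([3, 3, 3])     # 7
stacked_params = 3 * num_weights(3, C)      # 27 * C^2
single_params = num_weights(7, C)           # 49 * C^2
```

The stack of three 3 x 3 layers covers the same 7 x 7 window with 27/49 of the weights, and has two extra nonlinearities in between.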

Note that AlexNet has 60M free parameters (which is much more than the 1M labeled examples), 
mostly due to the three fully connected layers at the output. Fitting this model relied on using two 
GPUs (due to limited memory of GPUs at that time), and is widely considered an engineering tour 
de force.⁴ Figure 1.14a shows some predictions made by the model on some images from ImageNet.


4. The 3 authors of the paper (Alex Krizhevsky, Ilya Sutskever and Geoff Hinton) were subsequently hired by Google, 
although Ilya left in 2015, and Alex left in 2017. For more historical details, see https://en.wikipedia.org/wiki/AlexNet.
Note that AlexNet was not the first CNN implemented on a GPU; that honor goes to a group at Microsoft


Figure 14.18: Inception module. The 1 x 1 convolutional layers reduce the number of channels, keeping the 
spatial dimensions the same. The parallel pathways through convolutions of different sizes allows the model to 
learn which filter size to use for each layer. The final depth concatenation block combines the outputs of all 
the different pathways (which all have the same spatial size). From Figure 7.4.1 of [Zha+20]. Used with kind 
permission of Aston Zhang.

Figure 14.19: GoogLeNet (slightly simplified from the original). Input is on the left. From Figure 7.4.2 of
[Zha+20]. Used with kind permission of Aston Zhang.

14.3.3 GoogLeNet (Inception)


Researchers at Google developed a model known as GoogLeNet [Sze+15a]. (The name is a pun on Google and
LeNet.) The main difference from earlier models is that GoogLeNet used a new kind of block, known
as an inception block⁵, that employs multiple parallel pathways, each of which has a convolutional
filter of a different size. See Figure 14.18 for an illustration. This lets the model learn what the 
optimal filter size should be at each level. The overall model consists of 9 inception blocks followed by 
global average pooling. See Figure 14.19 for an illustration. Since this model first came out, various 
extensions were proposed; details can be found in [IS15; Sze+15b; SIV17].

Figure 14.20: A residual block for a CNN. Left: standard version. Right: version with 1x1 convolution, to
allow a change in the number of channels between the input to the block and the output. From Figure 7.6.3 of
[Zha+20]. Used with kind permission of Aston Zhang.

Figure 14.21: The ResNet-18 architecture. Each dotted module is a residual block shown in Figure 14.20.
From Figure 7.6.4 of [Zha+20]. Used with kind permission of Aston Zhang.


14.3.4 ResNet 


The winner of the 2015 ImageNet classification challenge was a team at Microsoft, who proposed a 
model known as ResNet [He+16a]. The key idea is to replace $x_{l+1} = F_l(x_l)$ with


$$x_{l+1} = \varphi(x_l + F_l(x_l)) \tag{14.22}$$


This is known as a residual block, since $F_l$ only needs to learn the residual, or difference, between
input and output of this layer, which is a simpler task. In [He+16a], $F$ has the form conv-BN-relu-conv-BN,
where conv is a convolutional layer, and BN is a batch norm layer (Section 14.2.4.1). See
Figure 14.20(left) for an illustration.


[CPS06], who got a 4x speedup over CPUs, and then [Cir+11], who got a 60x speedup.
5. This term comes from the movie Inception, in which the phrase “We need to go deeper” was uttered. This became a
popular meme in 2014.

We can ensure the spatial dimensions of the output $F_l(x_l)$ of the convolutional layer match those
of the input $x_l$ by using padding. However, if we want to allow for the output of the convolutional
layer to have a different number of channels, we need to add 1 x 1 convolution to the skip connection
on $x_l$. See Figure 14.20(right) for an illustration.
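To make the role of the 1 x 1 convolution on the skip connection concrete, here is a simplified NumPy sketch of a residual block. As a loudly stated simplification, the residual function uses two 1 x 1 convolutions (per-pixel linear maps) with random weights in place of the conv-BN-relu-conv-BN stack of [He+16a]; all function and variable names are hypothetical.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def conv1x1(x, W):
    """1x1 convolution: apply the same (C_in, C_out) linear map at every pixel of x (H, W, C_in)."""
    return x @ W

def residual_block(x, W1, W2, W_skip=None):
    """x_{l+1} = relu(skip(x_l) + F(x_l)), with F simplified to conv1x1-relu-conv1x1."""
    F = conv1x1(relu(conv1x1(x, W1)), W2)
    skip = x if W_skip is None else conv1x1(x, W_skip)   # project only if the channel count changes
    return relu(skip + F)

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 8, 16))

# Same number of channels: identity skip connection.
y_same = residual_block(x, rng.normal(size=(16, 16)), rng.normal(size=(16, 16)))

# Channel count changes from 16 to 32: the skip path needs its own 1x1 convolution.
y_wide = residual_block(x, rng.normal(size=(16, 32)), rng.normal(size=(32, 32)),
                        W_skip=rng.normal(size=(16, 32)))
```

Without `W_skip`, the addition in the second call would fail, since a 16-channel input cannot be added to a 32-channel output.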

The use of residual blocks allows us to train very deep models. The reason this is possible is that 
gradient can flow directly from the output to earlier layers, via the skip connections, for reasons 
explained in Section 13.4.4. 

In [He+16a] they trained a 152 layer ResNet on ImageNet. However, it is common to use shallower 
models. For example, Figure 14.21 shows the ResNet-18 architecture, which has 18 trainable layers: 
there are 2 3x3 conv layers in each residual block, and there are 8 such blocks, with an initial 7x7 
conv (stride 2) and a final fully connected layer. Symbolically, we can define the model as follows: 


(Conv : BN : Max) : (R : R) : (R' : R) : (R' : R) : (R' : R) : Avg : FC


where R is a residual block, R’ is a residual block with skip connection (due to the change in the 
number of channels) with stride 2, FC is fully connected (dense) layer, and : denotes concatenation. 
Note that the input size gets reduced spatially by a factor of $2^5 = 32$ (factor of 2 for each R' block,
plus the initial Conv-7x7(2) and Max-pool), so a 224x224 image becomes a 7x7 image before going
into the global average pooling layer. 
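The layer count and the spatial reduction can be verified with a short calculation. The kernel sizes and padding values below for the max pool and the strided blocks are assumptions (they are common choices, but are not stated in the text); only the strides and the 7x7 initial convolution come from the description above.

```python
def conv_output_size(size, kernel, stride, padding):
    """Output size along one axis of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

# 1 initial conv + 8 residual blocks with 2 conv layers each + 1 fully connected layer.
num_layers = 1 + 8 * 2 + 1

size = 224
size = conv_output_size(size, kernel=7, stride=2, padding=3)   # initial 7x7 conv, stride 2
size = conv_output_size(size, kernel=3, stride=2, padding=1)   # max pool, stride 2
for _ in range(3):                                             # the three R' blocks, stride 2 each
    size = conv_output_size(size, kernel=3, stride=2, padding=1)
final_size = size
```

Under these assumptions the sizes go 224, 112, 56, 28, 14, 7, which is the overall factor of 32.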

Some code to fit these models can be found online.⁶

In [He+16b], they showed how a small modification of the above scheme allows us to train 
models with up to 1001 layers. The key insight is that the signal on the skip connections is 
still being attenuated due to the use of the nonlinear activation function after the addition step,
$x_{l+1} = \varphi(x_l + F_l(x_l))$. They showed that it is better to use


$$x_{l+1} = x_l + \varphi(F_l(x_l)) \tag{14.23}$$


This is called a preactivation resnet or PreResnet for short. Now it is very easy for the network
to learn the identity function at a given layer: if we use ReLU activations, we just need to ensure
that $F_l(x_l) = 0$, which we can do by setting the weights and biases to 0.
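The difference between the two orderings shows up even in a toy example. In the sketch below the residual function is a single linear map with all weights and biases set to zero (a stand-in chosen for illustration, not the form used in [He+16b]).

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def F(x, W, b):
    """A stand-in for the residual function F_l: a single linear map (an illustrative assumption)."""
    return x @ W + b

x = np.array([-2.0, -0.5, 1.0, 3.0])
W_zero, b_zero = np.zeros((4, 4)), np.zeros(4)

original = relu(x + F(x, W_zero, b_zero))        # Equation (14.22): activation after the addition
preact = x + relu(F(x, W_zero, b_zero))          # Equation (14.23): activation inside the branch
```

With zero weights, the preactivation form returns `x` unchanged, whereas the original form still clips the negative entries of `x` to zero, so the signal on the skip path is altered.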

An alternative to using a very deep model is to use a very “wide” model, with lots of feature 
channels per layer. This is the idea behind the wide resnet model [ZK16], which is quite popular. 


14.3.5 DenseNet 


In a residual net, we add the output of each function to its input. An alternative approach would 
be to concatenate the output with the input, as illustrated in Figure 14.22a. If we stack a series of 
such blocks, we can get an architecture similar to Figure 14.22b. This is known as a DenseNet
[Hua+17a], since each layer densely depends on all previous layers. Thus the overall model is
computing a function of the form


$$x \to [x, f_1(x), f_2(x, f_1(x)), f_3(x, f_1(x), f_2(x, f_1(x))), \ldots] \tag{14.24}$$


6. The notebook resnet_jax.ipynb fits this model on FashionMNIST. The notebook cifar10_cnn_lightning.ipynb fits
it on the more challenging CIFAR-10 dataset. The latter code uses various tricks to achieve 89% top-1 accuracy on the
CIFAR test set after 20 training epochs. The tricks are data augmentation (Section 19.1), consisting of random crops
and horizontal flips, and a one-cycle learning rate schedule (Section 8.4.3). If you use 50 epochs, and stochastic
weight averaging (Section 8.4.4), you can get to ~94% accuracy.

Figure 14.22: (a) Left: a residual block adds the output to the input. Right: a densenet block concatenates the
output with the input. (b) Illustration of a densenet. From Figures 7.7.1-7.7.2 of [Zha+20]. Used with kind 
permission of Aston Zhang. 


The dense connectivity increases the number of parameters, since the channels get stacked depthwise. 
We can compensate for this by adding 1 x 1 convolution layers in between. We can also add pooling 
layers with a stride of 2 to reduce the spatial resolution. (See densenet_jax.ipynb for some sample
code.) 
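The concatenation pattern of Equation (14.24), and the resulting growth in channel count, can be sketched as follows. Each layer here is a hypothetical random 1 x 1 convolution followed by ReLU that emits `growth` new channels; real DenseNet layers are more elaborate, so treat this only as an illustration of the wiring.

```python
import numpy as np

def dense_block(x, layer_fns):
    """Concatenate each layer's output with everything computed so far, along the channel axis."""
    features = x
    for f in layer_fns:
        features = np.concatenate([features, f(features)], axis=-1)
    return features

def make_layer(c_in, growth, rng):
    """Hypothetical layer: a random 1x1 convolution followed by ReLU, producing `growth` channels."""
    W = rng.normal(size=(c_in, growth))
    return lambda h: np.maximum(h @ W, 0.0)

rng = np.random.default_rng(0)
c_in, growth, num_layers = 3, 4, 3
layers = [make_layer(c_in + i * growth, growth, rng) for i in range(num_layers)]
x = rng.normal(size=(8, 8, c_in))
out = dense_block(x, layers)
```

The output has `c_in + num_layers * growth` channels, and the input `x` is still present, unchanged, in its first `c_in` channels.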

DenseNets can perform better than ResNets, since all previously computed features are directly 
accessible to the output layer. However, they can be more computationally expensive. 


14.3.6 Neural architecture search 


We have seen how many CNNs are fairly similar in their design, and simply rearrange various building 
blocks (such as convolutional or pooling layers) in different topologies, and adjust various parameter 
settings (e.g., stride, number of channels, or learning rate). Indeed, the recent ConvNeXt model 
of [Liu+22] — which, at the time of writing (April 2022) is considered the state of the art CNN 
architecture for a wide variety of vision tasks — was created by combining multiple such small 
improvements on top of a standard ResNet architecture. 

We can automate this design process using blackbox (derivative free) optimization methods to find 
architectures that minimize the validation loss. This is called AutoML (see e.g., [HKV19]). In the 
context of neural nets, it is called neural architecture search or NAS [EMH19]. 

When performing NAS, we can optimize for multiple objectives at the same time, such as accuracy, 
model size, training or inference speed, etc. (this is how EfficientNetv2 was created [TL21]). The
main challenge arises due to the expense of computing the objective (since it requires training each 
candidate point in model space). One way to reduce the number of calls to the objective function 
is to use Bayesian optimization (see e.g., [WNS19]). Another approach is to create differentiable 
approximations to the loss (see e.g., [LSY19; Wan+21]), or to convert the architecture into a kernel
function (using the neural tangent kernel method, Section 17.2.8), and then to analyze properties of 
its eigenvalues, which can predict performance without actually training the model [CGW21]. The 
field of NAS is very large and still growing. See [EMH19] for a more thorough review. 


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




Figure 14.23: Dilated convolution with a 3x3 filter using rate 1, 2 and 3. From Figure 1 of [Cui+19]. Used
with kind permission of Ximin Cui.


14.4 Other forms of convolution * 


We discussed the basics of convolution in Section 14.2. In this section, we discuss some extensions, 
which are needed for applications such as image segmentation and image generation. 


14.4.1 Dilated convolution 


Convolution is an operation that combines the pixel values in a local neighborhood. By using striding, 
and stacking many layers of convolution together, we can enlarge the receptive field of each neuron, 
which is the region of input space that each neuron responds to. However, we would need many 
layers to give each neuron enough context to cover the entire image (unless we used very large filters, 
which would be slow and require too many parameters). 

As an alternative, we can use convolution with holes [Mal99], sometimes known by the French 
term à trous algorithm, and recently renamed dilated convolution [YK16]. This method simply
takes every r’th input element when performing convolution, where r is known as the rate or
dilation factor. For example, in 1d, convolving with filter $w$ using rate $r = 2$ is equivalent to regular
convolution using the filter $\tilde{w} = [w_1, 0, w_2, 0, w_3]$, where we have inserted 0s to expand the receptive
field (hence the term “convolution with holes”). This allows us to get the benefit of increased receptive
fields without increasing the number of parameters or the amount of compute. See Figure 14.23 for 
an illustration. 
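The equivalence between dilation and zero insertion is easy to verify in 1d. The helper names below are hypothetical, and the code computes cross-correlation (no filter flipping), which is the convention used for convolution in this chapter.

```python
import numpy as np

def dilated_conv1d(x, w, rate):
    """Valid 1d cross-correlation of x with filter w, taking every `rate`-th input element."""
    span = (len(w) - 1) * rate + 1                 # receptive field of the dilated filter
    return np.array([sum(w[u] * x[i + rate * u] for u in range(len(w)))
                     for i in range(len(x) - span + 1)])

def dilate_filter(w, rate):
    """Insert rate - 1 zeros between the taps of w, e.g. [w1, 0, w2, 0, w3] for rate 2."""
    w_dilated = np.zeros((len(w) - 1) * rate + 1)
    w_dilated[::rate] = w
    return w_dilated

x = np.arange(10, dtype=float)
w = np.array([1.0, -2.0, 3.0])

direct = dilated_conv1d(x, w, rate=2)
via_zeros = dilated_conv1d(x, dilate_filter(w, 2), rate=1)
```

Both routes give the same output, but the direct version only touches the three nonzero taps, which is why dilation enlarges the receptive field without adding parameters or compute.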

More precisely, dilated convolution in 2d is defined as follows: 


$$z_{i,j,d} = b_d + \sum_{u=0}^{H-1} \sum_{v=0}^{W-1} \sum_{c=0}^{C-1} x_{i+ru,\, j+rv,\, c}\, w_{u,v,c,d} \tag{14.25}$$


where we assume the same rate $r$ for both height and width, for simplicity. Compare this to
Equation (14.15), where the stride parameter uses $x_{si+u,\, sj+v,\, c}$.


14.4.2 Transposed convolution 


In convolution, we reduce from a large input X to a small output Y by taking a weighted combination 
of the input pixels and the convolutional kernel K. This is easiest to explain in code:


Figure 14.24: Transposed convolution with 2x2 kernel. From Figure 13.10.1 of [Zha+20]. Used with kind 
permission of Aston Zhang. 


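import numpy as np

def conv(X, K):
    # Valid 2d cross-correlation: slide the kernel K over the input X.
    h, w = K.shape
    Y = np.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Weighted sum of the h x w input patch whose top-left corner is (i, j).
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y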


In transposed convolution, we do the opposite, in order to produce a larger output from a smaller 
input: 


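import numpy as np

def trans_conv(X, K):
    # Transposed convolution: each input pixel stamps a scaled copy of K onto the output.
    h, w = K.shape
    Y = np.zeros((X.shape[0] + h - 1, X.shape[1] + w - 1))
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            # Overlapping copies are summed.
            Y[i:i + h, j:j + w] += X[i, j] * K
    return Y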


This is equivalent to padding the input image with (h - 1, w - 1) 0s (on the bottom right), where
(h, w) is the kernel size, then placing a weighted copy of the kernel on each one of the input locations, 
where the weight is the corresponding pixel value, and then adding up. This process is illustrated in 
Figure 14.24. We can think of the kernel as a “stencil” that is used to generate the output, modulated 
by the weights in the input. 

The term “transposed convolution” comes from the interpretation of convolution as matrix multi- 
plication, which we discussed in Section 14.2.1.3. If W is the matrix derived from kernel K using 
the process illustrated in Equation (14.9), then one can show that Y = transposed-conv(X, K) is
equivalent to $Y = \text{reshape}(W^{\mathsf{T}} \text{vec}(X))$. See transposed_conv_jax.ipynb for a demo.
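The matrix view can be checked directly on a small example. The sketch below restates the two loop-based functions so that it is self-contained, and adds a hypothetical helper `conv_matrix` that builds the matrix W for a given kernel and input shape, using row-major vectorization (an assumption about the vec ordering).

```python
import numpy as np

def conv(X, K):
    h, w = K.shape
    Y = np.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y

def trans_conv(X, K):
    h, w = K.shape
    Y = np.zeros((X.shape[0] + h - 1, X.shape[1] + w - 1))
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            Y[i:i + h, j:j + w] += X[i, j] * K
    return Y

def conv_matrix(K, in_shape):
    """Build W such that W @ vec(X) == vec(conv(X, K)) for X of shape in_shape (row-major vec)."""
    h, w = K.shape
    out_shape = (in_shape[0] - h + 1, in_shape[1] - w + 1)
    W = np.zeros((out_shape[0] * out_shape[1], in_shape[0] * in_shape[1]))
    for i in range(out_shape[0]):
        for j in range(out_shape[1]):
            patch = np.zeros(in_shape)
            patch[i:i + h, j:j + w] = K            # kernel placed at output location (i, j)
            W[i * out_shape[1] + j] = patch.ravel()
    return W

K = np.array([[0.0, 1.0], [2.0, 3.0]])
X_large = np.arange(9.0).reshape(3, 3)             # input to ordinary convolution
X_small = np.arange(4.0).reshape(2, 2)             # input to transposed convolution

W = conv_matrix(K, X_large.shape)                  # shape (4, 9)
Y_conv = (W @ X_large.ravel()).reshape(2, 2)
Y_trans = (W.T @ X_small.ravel()).reshape(3, 3)
```

Multiplying by W maps a 3 x 3 input to a 2 x 2 output, and multiplying by its transpose maps a 2 x 2 input back up to 3 x 3, matching `trans_conv`.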

Note that transposed convolution is also sometimes called deconvolution, but this is an incorrect 
usage of the term: deconvolution is the process of “undoing” the effect of convolution with a known 
filter, such as a blur filter, to recover the original input, as illustrated in Figure 14.25.


Figure 14.25: Convolution, deconvolution and transposed convolution. Here s is the stride and p is the
padding. From https://tinyurl.com/ynzcasut. Used with kind permission of Aqeel Anwar.


14.4.3 Depthwise separable convolution 


Standard convolution uses a filter of size H x W x C x D, which requires a lot of data to learn and a 
lot of time to compute with. A simplification, known as depthwise separable convolution, first 
convolves each input channel by a corresponding 2d filter w, and then maps these C channels to D 
channels using 1 x 1 convolution w’: 


$$z_{i,j,d} = b_d + \sum_{c=0}^{C-1} w'_{c,d} \left( \sum_{u=0}^{H-1} \sum_{v=0}^{W-1} x_{i+u,\, j+v,\, c}\, w_{u,v,c} \right) \tag{14.26}$$


See Figure 14.26 for an illustration. 

To see the advantage of this, let us consider a simple numerical example.⁷ Regular convolution of a
12 x 12 x 3 input with a 5 x 5 x 3 x 256 filter gives an 8 x 8 x 256 output (assuming valid convolution:
12-5+1=8), as illustrated in Figure 14.13. With separable convolution, we start with 12 x 12 x 3 input,
convolve each channel with its own 5 x 5 x 1 x 1 filter (across space but not channels) to get 8 x 8 x 3, then pointwise
convolve (across channels but not space) with a 1 x 1 x 3 x 256 filter to get an 8 x 8 x 256 output. So
the output has the same size as before, but we used many fewer parameters to define the layer, and
used much less compute. For this reason, separable convolution is often used in lightweight CNN
models, such as the MobileNet model [How+17; San+18a], which target mobile and other edge devices.
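The savings in this example can be computed explicitly. The sketch below counts weights (ignoring biases) and multiply-add operations for both variants; the function names are hypothetical, and the cost model simply multiplies the number of weights by the number of output locations.

```python
def regular_conv_cost(H, W, C, D, out_h, out_w):
    """Weights and multiply-adds for a regular H x W x C x D convolution."""
    params = H * W * C * D
    return params, params * out_h * out_w

def separable_conv_cost(H, W, C, D, out_h, out_w):
    """Weights and multiply-adds for depthwise (one H x W filter per channel) plus 1x1 pointwise."""
    depthwise = H * W * C
    pointwise = C * D
    return depthwise + pointwise, (depthwise + pointwise) * out_h * out_w

out = 12 - 5 + 1                                     # valid convolution: 8
reg_params, reg_macs = regular_conv_cost(5, 5, 3, 256, out, out)
sep_params, sep_macs = separable_conv_cost(5, 5, 3, 256, out, out)
```

For this example the regular layer has 19,200 weights and about 1.2M multiply-adds, while the separable version has 843 weights and about 54K multiply-adds, a reduction of roughly 23x in both.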


14.5 Solving other discriminative vision tasks with CNNs * 


In this section, we briefly discuss how to tackle various other vision tasks using CNNs. Each task 
also introduces a new architectural innovation to the library of basic building blocks we have already 
seen. More details on CNNs for computer vision can be found in e.g., [Bro19]. 


14.5.1 Image tagging 


Image classification associates a single label with the whole image, i.e., the outputs are assumed to 
be mutually exclusive. In many problems, there may be multiple objects present, and we want to 
label all of them. This is known as image tagging, and is an application of multi-label prediction. 
In this case, we define the output space as $\mathcal{Y} = \{0,1\}^C$, where $C$ is the number of tag types. Since
the output bits are independent (given the image), we should replace the final softmax with a set of 
C logistic units. 
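The difference between the two output layers can be seen by applying both to the same logits. The scores and tag names below are made up for illustration, and the 0.5 decision threshold is an arbitrary choice.

```python
import numpy as np

def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()

def sigmoid(logits):
    return 1.0 / (1.0 + np.exp(-logits))

logits = np.array([4.0, 3.5, -2.0])     # hypothetical scores for the tags (dog, ball, car)

class_probs = softmax(logits)           # mutually exclusive: probabilities sum to 1
tag_probs = sigmoid(logits)             # independent: each tag has its own probability
tags = tag_probs > 0.5                  # the 0.5 threshold is an illustrative choice
```

With the softmax, the two strongly supported labels must share the probability mass, whereas the logistic units can assign high probability to both at once.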


7. This example is from https://bit.ly/2Uj64Vo by Chi-Feng Wang. 


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 




Figure 14.26: Depthwise separable convolutions: each of the C input channels undergoes a 2d convolution to
produce C output channels, which get combined pointwise (via 1x1 convolution) to produce D output channels.
From https://bit.ly/2L9fm20. Used with kind permission of Eugenio Culurciello.


Users of social media sites like Instagram often create hashtags for their images; this therefore 
provides a “free” way of creating large supervised datasets. Of course, many tags may be quite 
sparsely used, and their meaning may not be well-defined visually. (For example, someone may take 
a photo of themselves after they get a COVID test and tag the image “#covid”; however, visually 
it just looks like any other image of a person.) Thus this kind of user-generated labeling is usually 
considered quite noisy. However, it can be useful for “pre-training”, as discussed in [Mah+18]. 

Finally, it is worth noting that image tagging is often a much more sensible objective than image 
classification, since many images have multiple objects in them, and it can be hard to know which one 
we should be labeling. Indeed, Andrej Karpathy, who created the “human performance benchmark” 
on ImageNet, noted the following:⁸


Both [CNNs] and humans struggle with images that contain multiple ImageNet classes (usually 
many more than five), with little indication of which object is the focus of the image. This error 
is only present in the classification setting, since every image is constrained to have exactly one 
correct label. In total, we attribute 16% of human errors to this category. 


14.5.2 Object detection 


In some cases, we want to produce a variable number of outputs, corresponding to a variable number 
of objects of interest that may be present in the image. (This is an example of an open world 
problem, with an unknown number of objects.) 

A canonical example of this is object detection, in which we must return a set of bounding 
boxes representing the locations of objects of interest, together with their class labels. A special 
case of this is face detection, where there is only one class of interest. This is illustrated in 
Figure 14.27a.⁹


8. Source: https://bit.ly/3cFbALk 
9. Note that face detection is different from face recognition, which is a classification task that tries to predict the
Figure 14.27: (a) Illustration of face detection, a special case of object detection. (Photo of author and his
wife Margaret, taken at Filoli in California in February, 2018. Image processed by Jonathan Huang using
SSD face model.) (b) Illustration of anchor boxes. Adapted from [Zha+20, Sec 12.5]. 


The simplest way to tackle such detection problems is to convert it into a closed world problem, in 
which there is a finite number of possible locations (and orientations) any object can be in. These 
candidate locations are known as anchor boxes. We can create boxes at multiple locations, scales 
and aspect ratios, as illustrated in Figure 14.27b. For each box, we train the system to predict what 
category of object it contains (if any); we can also perform regression to predict the offset of the 
object location from the center of the anchor. (These residual regression terms allow sub-grid spatial 
localization.) 

Abstractly, we are learning a function of the form 


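$$f_{\boldsymbol{\theta}} : \mathbb{R}^{H \times W \times K} \rightarrow [0,1]^{A \times A} \times \{1,\ldots,C\}^{A \times A} \times (\mathbb{R}^{4})^{A \times A} \tag{14.27}$$


where H × W is the spatial size of the input, K is the number of input channels, A is the number of 
anchor boxes in each dimension, and C is the number of object types (class labels). For each box 
location (i, j), we predict three outputs: an object presence probability, $p_{ij} \in [0,1]$, an object 
category, $y_{ij} \in \{1,\ldots,C\}$, and two 2d offset vectors, $\boldsymbol{\delta}_{ij} \in \mathbb{R}^4$, which can be added to the 
centroid of the box to get the top left and bottom right corners. 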

Several models of this type have been proposed, including the single shot detector model of 
[Liu+16], and the YOLO (you only look once) model of [Red+16]. Many other methods for object 
detection have been proposed over the years. These models make different tradeoffs between speed, 
accuracy, simplicity, etc. See [Hua+17b] for an empirical comparison, and [Zha+18] for a more recent 
review. 
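
The anchor-box construction described above can be sketched in a few lines of NumPy. This is only an illustrative sketch, not the procedure used by SSD or YOLO: the function names `make_anchor_boxes` and `decode_box`, the use of normalized coordinates, and the parameterization of box shapes by a scale and an aspect ratio are all assumptions made for the example.

```python
import numpy as np

def make_anchor_boxes(grid_size, scales, aspect_ratios):
    """Return boxes of shape (A, A, B, 4) as (x_min, y_min, x_max, y_max).

    Coordinates are normalized to [0, 1]; A = grid_size and
    B = len(scales) * len(aspect_ratios). (Hypothetical helper.)
    """
    # Centers of an A x A grid of cells covering the unit square.
    centers = (np.arange(grid_size) + 0.5) / grid_size
    cy, cx = np.meshgrid(centers, centers, indexing="ij")
    shapes = []
    for s in scales:
        for r in aspect_ratios:
            # Width and height chosen so that area = s^2 and width / height = r.
            shapes.append((s * np.sqrt(r), s / np.sqrt(r)))
    shapes = np.array(shapes)                     # (B, 2)
    w = shapes[:, 0][None, None, :]
    h = shapes[:, 1][None, None, :]
    cx = cx[:, :, None]
    cy = cy[:, :, None]
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=-1)

def decode_box(center, delta):
    """Add a predicted offset vector in R^4 to the anchor centroid.

    The first two entries of delta move the centroid to the top left corner,
    the last two move it to the bottom right corner.
    """
    return np.concatenate([center, center]) + delta

anchors = make_anchor_boxes(grid_size=4, scales=[0.25, 0.5],
                            aspect_ratios=[0.5, 1.0, 2.0])
print(anchors.shape)  # (4, 4, 6, 4)
corners = decode_box(np.array([0.5, 0.5]), np.array([-0.1, -0.2, 0.1, 0.2]))
```

A detector would then predict, for every one of these candidate boxes, a presence probability, a class label, and an offset vector such as `delta`.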


14.5.3 Instance segmentation 


In object detection, we predict a label and bounding box for each object. In instance segmentation, 
the goal is to predict the label and 2d shape mask of each object instance in the image, as illustrated 
in Figure 14.28. This can be done by applying a semantic segmentation model to each detected box, 
which has to label each pixel as foreground or background. (See Section 14.5.4 for more details on 
semantic segmentation.) 


9. (continued) identity of a person from a set or “gallery” of possible people. Face recognition is usually solved by applying the 
classifier to all the patches that are detected as containing faces. 

Figure 14.28: Illustration of object detection and instance segmentation using Mask R-CNN. From 
https://github.com/matterport/Mask_RCNN. Used with kind permission of Waleed Abdulla. 

[Figure 14.29 diagram: a convolutional encoder-decoder mapping an RGB input image to an output segmentation. 
Legend: Conv + Batch Normalisation + ReLU, Pooling, Upsampling, Softmax; pooling indices are passed from the encoder to the decoder.] 

Figure 14.29: Illustration of an encoder-decoder (aka U-net) CNN for semantic segmentation. The encoder 
uses convolution (which downsamples), and the decoder uses transposed convolution (which upsamples). From 
Figure 1 of [BKC17]. Used with kind permission of Alex Kendall. 


14.5.4 Semantic segmentation 


In semantic segmentation, we have to predict a class label $y_i \in \{1,\ldots,C\}$ for each pixel, where 
the classes may represent things like sky, road, car, etc. In contrast to instance segmentation, which 
we discussed in Section 14.5.3, all car pixels get the same label, so semantic segmentation does not 
differentiate between objects. We can combine semantic segmentation of “stuff” (like sky, road) and 
instance segmentation of “things” (like car, person) into a coherent framework called “panoptic 
segmentation” [Kir+19]. 

A common way to tackle semantic segmentation is to use an encoder-decoder architecture, as 
illustrated in Figure 14.29. The encoder uses standard convolution to map the input into a small 2d 
bottleneck, which captures high level properties of the input at a coarse spatial resolution. (This 
typically uses a technique called dilated convolution that we explain in Section 14.4.1, to capture a 
large field of view, i.e., more context.) The decoder maps the small 2d bottleneck back to a full-sized 
output image using a technique called transposed convolution that we explain in Section 14.4.2. Since 
the bottleneck loses information, we can also add skip connections from input layers to output layers. 
We can redraw this model as shown in Figure 14.30. Since the overall structure resembles the letter 
U, this is also known as a U-net [RFB15]. 

Figure 14.30: Illustration of the U-Net model for semantic segmentation. Each blue box corresponds to a 
multi-channel feature map. The number of channels is shown on the top of the box, and the height/width is 
shown in the bottom left. White boxes denote copied feature maps. The different colored arrows correspond to 
different operations. From Figure 1 of [RFB15]. Used with kind permission of Olaf Ronneberger. 

Figure 14.31: Illustration of a multi-task dense prediction problem. (Panel titles: Input, Depth, Normals, 
Labels.) From Figure 1 of [EF15]. Used with kind permission of Rob Fergus. 

A similar encoder-decoder architecture can be used for other dense prediction or image-to-image 
tasks, such as depth prediction (predict the distance from the camera, $z_i \in \mathbb{R}$, for each 
pixel i), surface normal prediction (predict the orientation of the surface, $\mathbf{z}_i \in \mathbb{R}^3$, at each image 
patch), etc. We can of course train one model to solve all of these tasks simultaneously, using multiple 
output heads, as illustrated in Figure 14.31. (See e.g., [Kok17] for details.) 


14.5.5 Human pose estimation 


We can train an object detector to detect people, and to predict their 2d shape, as represented by a 
mask. However, we can also train the model to predict the location of a fixed set of skeletal keypoints, 
e.g., the location of the head or hands. This is called human pose estimation. See Figure 14.32 
for an example. There are several techniques for this, e.g., PersonLab [Pap+18] and OpenPose 
[Cao+18]. See [Bab19] for a recent review. 

Figure 14.32: Illustration of keypoint detection for body, hands and face using the OpenPose system. From 
Figure 8 of [Cao+18]. Used with kind permission of Yaser Sheikh. 

We can also predict 3d properties of each detected object. The main limitation is the ability to 
collect enough labeled training data, since it is difficult for human annotators to label things in 3d. 
However, we can use computer graphics engines to create simulated images with infinite ground 
truth 3d annotations (see e.g., [GNK18]). 


14.6 Generating images by inverting CNNs * 


A CNN trained for image classification is a discriminative model of the form $p(y|\mathbf{x})$, which takes as 
input an image, and returns as output a probability distribution over C class labels. In this section we 
discuss how to “invert” this model, by converting it into a (conditional) generative image model 
of the form $p(\mathbf{x}|y)$. This will allow us to generate images that belong to a specific class. (We discuss 
more principled approaches to creating generative models for images in the sequel to this book, 
[Mur23].) 


14.6.1 Converting a trained classifier into a generative model 


We can define a joint distribution over images and labels using $p(\mathbf{x}, y) = p(\mathbf{x})p(y|\mathbf{x})$, where $p(y|\mathbf{x})$ 
is the CNN classifier, and $p(\mathbf{x})$ is some prior over images. If we then clamp the class label to a 
specific value, we can create a conditional generative model using $p(\mathbf{x}|y) \propto p(\mathbf{x})p(y|\mathbf{x})$. Note that the 
discriminative classifier $p(y|\mathbf{x})$ was trained to “throw away” information, so $p(y|\mathbf{x})$ is not an invertible 
function. Thus the prior term $p(\mathbf{x})$ will play an important role in regularizing this process, as we see 
in Section 14.6.2. 

One way to sample from this model is to use the Metropolis Hastings algorithm (Section 4.6.8.4), 
treating $\mathcal{E}_c(\mathbf{x}) = \log p(y = c|\mathbf{x}) + \log p(\mathbf{x})$ as the energy function. Since gradient information is 
available, we can use a proposal of the form $q(\mathbf{x}'|\mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}(\mathbf{x}), \epsilon \mathbf{I})$, where $\boldsymbol{\mu}(\mathbf{x}) = \mathbf{x} + \frac{\epsilon}{2} \nabla \mathcal{E}_c(\mathbf{x})$. 
This is called the Metropolis-adjusted Langevin algorithm (MALA). As an approximation, we 
can ignore the rejection step, and accept every proposal. This is called the unadjusted Langevin 
algorithm, and was used in [Ngu+17] for conditional image generation. In addition, we can scale 
the gradient of the log prior and log likelihood independently. Thus we get an update over the space 
of images that looks like a noisy version of SGD, except we take derivatives wrt the input pixels 
(using Equation (13.50)), instead of the parameters: 


$$\mathbf{x}_{t+1} = \mathbf{x}_t + \epsilon_1 \frac{\partial \log p(\mathbf{x}_t)}{\partial \mathbf{x}_t} + \epsilon_2 \frac{\partial \log p(y = c|\mathbf{x}_t)}{\partial \mathbf{x}_t} + \mathcal{N}(\mathbf{0}, \epsilon_3^2 \mathbf{I}) \tag{14.28}$$


We can interpret each term in this equation as follows: the $\epsilon_1$ term ensures the image is plausible 
under the prior, the $\epsilon_2$ term ensures the image is plausible under the likelihood, and the $\epsilon_3$ term is a 
noise term, in order to generate diverse samples. If we set $\epsilon_3 = 0$, the method becomes a deterministic 
algorithm to (approximately) generate the “most likely image” for this class. 
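
To make the update in Equation (14.28) concrete, here is a minimal NumPy sketch on a toy problem. Everything specific in it is an assumption made for illustration: the "classifier" is a random linear softmax model rather than a CNN, the "image" is a 5-dimensional vector, the prior is a standard Gaussian, and the names `langevin_step` and `log_joint` are hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def grad_log_prior(x):
    # Gradient of log N(x | 0, I) with respect to x.
    return -x

def grad_log_lik(x, W, c):
    # Gradient of log softmax(W x)[c] with respect to the input x.
    p = softmax(W @ x)
    return W[c] - p @ W

def langevin_step(x, W, c, eps1, eps2, eps3, rng):
    # One step of the unadjusted Langevin update with separately scaled terms.
    noise = eps3 * rng.standard_normal(x.shape)   # sample from N(0, eps3^2 I)
    return x + eps1 * grad_log_prior(x) + eps2 * grad_log_lik(x, W, c) + noise

def log_joint(x, W, c):
    # log p(x) + log p(y=c|x), dropping the constant of the Gaussian prior.
    return -0.5 * x @ x + np.log(softmax(W @ x)[c])

rng = np.random.default_rng(0)
W = rng.standard_normal((3, 5))   # toy linear classifier: 3 classes, 5 "pixels"
x = rng.standard_normal(5)        # random starting "image"
start = log_joint(x, W, c=1)
for _ in range(200):
    # eps3 = 0 gives the deterministic "most likely image" variant.
    x = langevin_step(x, W, c=1, eps1=0.05, eps2=0.05, eps3=0.0, rng=rng)
end = log_joint(x, W, c=1)
```

With `eps3 = 0` the loop is plain gradient ascent on the log joint, so `end` exceeds `start`; setting `eps3 > 0` injects noise and produces diverse samples instead of a single mode.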


14.6.2 Image priors 


In this section, we discuss various kinds of image priors that we can use to regularize the ill-posed 
problem of inverting a classifier. These priors, together with the image that we start the optimization 
from, will determine the kinds of outputs that we generate. 


14.6.2.1 Gaussian prior 


Just specifying the class label is not enough information to specify the kind of images we want. We 
also need a prior p(x) over what constitutes a “plausible” image. The prior can have a large effect on 
the quality of the resulting image, as we show below. 

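Arguably the simplest prior is $p(\mathbf{x}) = \mathcal{N}(\mathbf{x}|\mathbf{0}, \mathbf{I})$, as suggested in [SVZ14]. (This assumes the image 
pixels have been centered.) This can prevent pixels from taking on extreme values. In this case, the 
update due to the prior term has the form 


$$\nabla_{\mathbf{x}} \log p(\mathbf{x}) = \nabla_{\mathbf{x}} \left[ -\frac{1}{2} \|\mathbf{x} - \mathbf{0}\|_2^2 \right] = -\mathbf{x} \tag{14.29}$$


Thus the overall update (assuming $\epsilon_2 = 1$ and $\epsilon_3 = 0$) has the form 


$$\mathbf{x}_{t+1} = (1 - \epsilon_1)\mathbf{x}_t + \frac{\partial \log p(y = c|\mathbf{x}_t)}{\partial \mathbf{x}_t} \tag{14.30}$$


See Figure 14.33 for some samples generated by this method. 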


14.6.2.2 Total variation (TV) prior 


We can generate slightly more realistic looking images if we use additional regularizers. [MV15; 
MV16] suggested computing the total variation or TV norm of the image. This is equal to the 
integral of the per-pixel gradients, which can be approximated as follows: 


$$\text{TV}(\mathbf{x}) = \sum_{ijk} (x_{ijk} - x_{i+1,j,k})^2 + (x_{ijk} - x_{i,j+1,k})^2 \tag{14.31}$$


where $x_{ijk}$ is the pixel value in row i, column j and channel k (for RGB images). We can rewrite 
this in terms of the horizontal and vertical Sobel edge detector applied to each channel: 


$$\text{TV}(\mathbf{x}) = \sum_{k} \|\mathbf{H}(\mathbf{x}_{:,:,k})\|_F^2 + \|\mathbf{V}(\mathbf{x}_{:,:,k})\|_F^2 \tag{14.32}$$


Figure 14.33: Images that maximize the probability of ImageNet classes “goose” and “ostrich” under a simple 
Gaussian prior. From http://yosinski.com/deepvis. Used with kind permission of Jeff Clune. 


Figure 14.34: Illustration of total variation norm. (a) Input image: a green sea turtle (Used with kind 
permission of Wikimedia author P. Lindgren). (b) Horizontal deltas. (c) Vertical deltas. Adapted from 
https://www.tensorflow.org/tutorials/generative/style_transfer. 


See Figure 14.34 for an illustration of these edge detectors. Using $p(\mathbf{x}) \propto \exp(-\text{TV}(\mathbf{x}))$ discourages 
images from having high frequency artefacts. In [Yos+15], they use Gaussian blur instead of TV 
norm, but this has a similar effect. 

In Figure 14.35 we show some results of optimizing log p(y = c, x) using a TV prior and a CNN 
likelihood for different class labels c starting from random noise. 
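
The total variation penalty is simple to compute with array slicing. The sketch below is one possible discretization, assuming an image stored as an array of shape (rows, columns, channels) and dropping the difference terms that would reach past the image boundary (the text does not say how boundaries are handled); the function name `total_variation` is hypothetical.

```python
import numpy as np

def total_variation(x):
    """Sum of squared differences between neighboring pixels.

    x has shape (H, W, K). Differences that would need a pixel outside
    the image are simply omitted (an assumption of this sketch).
    """
    dv = x[1:, :, :] - x[:-1, :, :]   # differences between adjacent rows
    dh = x[:, 1:, :] - x[:, :-1, :]   # differences between adjacent columns
    return np.sum(dv ** 2) + np.sum(dh ** 2)

flat = np.ones((8, 8, 3))            # constant image: no variation
rng = np.random.default_rng(0)
noisy = flat + 0.5 * rng.standard_normal(flat.shape)
tv_flat = total_variation(flat)
tv_noisy = total_variation(noisy)
```

A constant image has zero total variation, while pixel-level noise gives a large value, which is why a prior proportional to exp(−TV) pushes the optimization towards smoother images.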


14.6.3 Visualizing the features learned by a CNN 


It is interesting to ask what the “neurons” in a CNN are learning. One way to do this is to start with 
a random image, and then to optimize the input pixels so as to maximize the average activation of a 
particular neuron. This is called activation maximization (AM), and uses the same technique as 
in Section 14.6.1 but fixes an internal node to a specific value, rather than clamping the output class 
label. 

Figure 14.36 illustrates the output of this method (with the TV prior) when applied to the AlexNet 
CNN trained on ImageNet classification. We see that, as the depth increases, neurons are learning to 
recognize simple edges/blobs, then texture patterns, then object parts, and finally whole objects. 
This is believed to be roughly similar to the hierarchical structure of the visual cortex (see e.g., 
[Kan+12]). 

Figure 14.35: Images that maximize the probability of certain ImageNet classes under a TV prior. (Panel 
titles: Anemone Fish, Banana, Parachute, Screw.) From 
https://research.googleblog.com/2015/06/inceptionism-going-deeper-into-neural.html. Used 
with kind permission of Alexander Mordvintsev. 

Figure 14.36: We visualize “optimal stimuli” for neurons in layers Conv 1, 3, 5 and fc8 in the AlexNet 
architecture, trained on the ImageNet dataset. For Conv5, we also show retrieved real images (under the 
column “data driven”) that produce similar activations. Based on the method in [MV16]. Used with kind 
permission of Donglai Wei. 


An alternative to optimizing in pixel space is to search the training set for images that maximally 
activate a given neuron. This is illustrated in Figure 14.36 for the Conv5 layer. 
For more information on feature visualization see e.g., [OMS17]. 


14.6.4 Deep Dream 


So far we have focused on generating images which maximize the class label or some other neuron of 
interest. In this section we tackle a more artistic application, in which we want to generate versions 
of an input image that emphasize certain features. 

To do this, we view our pre-trained image classifier as a feature extractor. Based on the results 
in Section 14.6.3, we know the activity of neurons in different layers corresponds to different kinds 
of features in the image. Suppose we are interested in “amplifying” features from layers $\ell \in \mathcal{L}$. 
We can do this by defining an energy or loss function of the form $\mathcal{L}(\mathbf{x}) = \sum_{\ell \in \mathcal{L}} \bar{\phi}_{\ell}(\mathbf{x})$, where 
$\bar{\phi}_{\ell} = \frac{1}{HWC} \sum_{hwc} \phi_{\ell hwc}(\mathbf{x})$ is the average feature activation for layer $\ell$. We can now use gradient ascent 
on the input pixels to increase this quantity (equivalently, gradient descent on its negative). The 
resulting process is called DeepDream [MOT15], since the model amplifies 
features that were only hinted at in the original image and then creates images with more and more 
of them.¹⁰ 

Figure 14.37: Illustration of DeepDream. The CNN is an Inception classifier trained on ImageNet. (a) 
Starting image of an Aurelia aurita (also called moon jelly). (b) Image generated after 10 iterations. (c) 
Image generated after 50 iterations. From https://en.wikipedia.org/wiki/DeepDream. Used with kind 
permission of Wikipedia author Martin Thoma. 

Figure 14.37 shows an example. We start with an image of a jellyfish, which we pass into a CNN 
that was trained to classify ImageNet images. After several iterations, we generate some image which 
is a hybrid of the input and the kinds of “hallucinations” we saw in Figure 14.33; these hallucinations 
involve dog parts, since ImageNet has so many kinds of dogs in its label set. See [Tho16] for details, 
and https://deepdreamgenerator.com for a fun web-based demo. 


14.6.5 Neural style transfer 


The DeepDream system in Figure 14.37 shows one way that CNNs can be used to create “art”. 
However, it is rather creepy. In this section, we discuss a related approach that gives the user more 
control. In particular, the user has to specify a reference “style image” $\mathbf{x}_s$ and “content image” $\mathbf{x}_c$. 
The system will then try to generate a new image $\mathbf{x}$ that “re-renders” $\mathbf{x}_c$ in the style of $\mathbf{x}_s$. This is 
called neural style transfer, and is illustrated in Figure 14.38 and Figure 14.39. This technique 
was first proposed in [GEB16], and there are now many papers on this topic; see [Jin+17] for a recent 
review. 


14.6.5.1 How it works 


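Style transfer works by optimizing the following energy function: 

$$\mathcal{L}(\mathbf{x}|\mathbf{x}_s, \mathbf{x}_c) = \lambda_{TV} \mathcal{L}_{TV}(\mathbf{x}) + \lambda_c \mathcal{L}_{\text{content}}(\mathbf{x}, \mathbf{x}_c) + \lambda_s \mathcal{L}_{\text{style}}(\mathbf{x}, \mathbf{x}_s) \tag{14.33}$$

See Figure 14.40 for a high level illustration. 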


10. The method was originally called Inceptionism, since it uses the inception CNN (Section 14.3.3). 


Figure 14.38: Example output from a neural style transfer system. (a) Content image: a green sea turtle (Used 
with kind permission of Wikimedia author P. Lindgren). (b) Style image: a painting by Wassily Kandinsky 
called “Composition 7”. (c) Output of neural style generation. Adapted from 
https://www.tensorflow.org/tutorials/generative/style_transfer. 


Figure 14.39: Neural style transfer applied to photos of the “production team”, who helped create code and 
demos for this book and its sequel. From top to bottom, left to right: Kevin Murphy (the author), Mahmoud 
Soliman, Aleyna Kara, Srikar Jilugu, Drishti Patel, Ming Liang Ang, Gerardo Durán-Martín, Coco (the 
team dog). Each content photo used a different artistic style. Adapted from 
https://www.tensorflow.org/tutorials/generative/style_transfer. 

Figure 14.40: Illustration of how neural style transfer works. Adapted from Figure 12.12.2 of [Zha+20]. 

Figure 14.41: Schematic representation of 3 kinds of feature maps for 3 different input images. Adapted from 
Figure 5.16 of [Fos19]. 


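The first term in Equation (14.33) is the total variation prior discussed in Section 14.6.2.2. The 
second term measures how similar $\mathbf{x}$ is to $\mathbf{x}_c$ by comparing feature maps of a pre-trained CNN $\phi(\mathbf{x})$ 
in the relevant “content layer” $\ell$: 


$$\mathcal{L}_{\text{content}}(\mathbf{x}, \mathbf{x}_c) = \frac{1}{C_{\ell} H_{\ell} W_{\ell}} \|\phi_{\ell}(\mathbf{x}) - \phi_{\ell}(\mathbf{x}_c)\|_2^2 \tag{14.34}$$
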

Finally we have to define the style term. We can interpret visual style as the statistical distribution 
of certain kinds of image features. The location of these features in the image may not matter, but 
their co-occurrence does. This is illustrated in Figure 14.41. It is clear (to a human) that image 1 is 
more similar in style to image 2 than to image 3. Intuitively this is because both image 1 and image 
2 have spiky green patches in them, whereas image 3 has spiky things that are not green. 
To capture the co-occurrence statistics we compute the Gram matrix for an image using feature 
maps from a specific layer $\ell$: 


$$G_{\ell}(\mathbf{x})_{c,d} = \frac{1}{H_{\ell} W_{\ell}} \sum_{h=1}^{H_{\ell}} \sum_{w=1}^{W_{\ell}} \phi_{\ell}(\mathbf{x})_{h,w,c} \, \phi_{\ell}(\mathbf{x})_{h,w,d} \tag{14.35}$$


The Gram matrix is a $C_{\ell} \times C_{\ell}$ matrix which is proportional to the uncentered covariance of the 
$C_{\ell}$-dimensional feature vectors sampled over each of the $H_{\ell} W_{\ell}$ locations. 
Given this, we define the style loss for layer $\ell$ as follows: 


$$\mathcal{L}^{\ell}_{\text{style}}(\mathbf{x}, \mathbf{x}_s) = \|\mathbf{G}_{\ell}(\mathbf{x}) - \mathbf{G}_{\ell}(\mathbf{x}_s)\|_F^2 \tag{14.36}$$


Finally, we define the overall style loss as a sum over the losses for a set $\mathcal{S}$ of layers: 


$$\mathcal{L}_{\text{style}}(\mathbf{x}, \mathbf{x}_s) = \sum_{\ell \in \mathcal{S}} \mathcal{L}^{\ell}_{\text{style}}(\mathbf{x}, \mathbf{x}_s) \tag{14.37}$$



For example, in Figure 14.40, we compute the style loss at layers 1 and 3. (Lower layers will capture 
visual texture, and higher layers will capture object layout.) 
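
The Gram matrix and the style loss can be sketched directly in NumPy. The example below assumes each layer's feature map is given as an array of shape (height, width, channels); in a real system these would come from a pre-trained CNN, but here random arrays stand in for them, and the function names are hypothetical.

```python
import numpy as np

def gram_matrix(feats):
    """Map an (H, W, C) feature map to its (C, C) Gram matrix."""
    H, W, C = feats.shape
    F = feats.reshape(H * W, C)          # one C-dimensional vector per location
    return F.T @ F / (H * W)             # uncentered covariance over locations

def style_loss(feats_x, feats_s):
    """Sum over layers of the squared Frobenius distance between Gram matrices."""
    return sum(np.sum((gram_matrix(fx) - gram_matrix(fs)) ** 2)
               for fx, fs in zip(feats_x, feats_s))

rng = np.random.default_rng(0)
f = rng.standard_normal((4, 5, 3))       # stand-in for one layer's features
# Shuffle the 20 spatial locations: the features move, but co-occurrences do not.
perm = rng.permutation(20)
f_shuffled = f.reshape(20, 3)[perm].reshape(4, 5, 3)
same_gram = np.allclose(gram_matrix(f), gram_matrix(f_shuffled))
loss_shuffled = style_loss([f], [f_shuffled])
loss_scaled = style_loss([f], [2 * f])
```

The shuffled feature map has the same Gram matrix as the original, which illustrates why this loss captures which features co-occur but not where they occur.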


14.6.5.2 Speeding up the method 


In [GEB16], they used L-BFGS (Section 8.3.2) to optimize Equation (14.33), starting from white 
noise. We can get faster results if we use an optimizer such as Adam instead of L-BFGS, and initialize 
from the content image instead of white noise. Nevertheless, running an optimizer for every new 
style and content image is slow. Several papers (see e.g., [JAFF16; Uly+16; UVL16; LW16]) have 
proposed to train a neural network to directly predict the outcome of this optimization, rather than 
solving it for each new image pair. (This can be viewed as a form of amortized optimization.) In 
particular, for every style image $\mathbf{x}_s$, we fit a model $f_s$ such that $f_s(\mathbf{x}_c) = \text{argmin}_{\mathbf{x}} \mathcal{L}(\mathbf{x}|\mathbf{x}_s, \mathbf{x}_c)$. We 
can then apply this model to new content images without having to reoptimize. 

More recently, [DSK16] has shown how it is possible to train a single network that takes as 
input both the content and a discrete representation s of the style, and then produces $f(\mathbf{x}_c, s) = 
\text{argmin}_{\mathbf{x}} \mathcal{L}(\mathbf{x}|s, \mathbf{x}_c)$ as the output. This avoids the need to train a separate network for every style 
image. The key idea is to standardize the features at a given layer using scale and shift parameters 
that are style specific. In particular, we use the following conditional instance normalization 
transformation: 


$$\text{CIN}(\phi(\mathbf{x}_c), s) = \gamma_s \left( \frac{\phi(\mathbf{x}_c) - \mu(\phi(\mathbf{x}_c))}{\sigma(\phi(\mathbf{x}_c))} \right) + \beta_s \tag{14.38}$$


where $\mu(\phi(\mathbf{x}_c))$ is the mean of the features in a given layer, $\sigma(\phi(\mathbf{x}_c))$ is the standard deviation, 
and $\beta_s$ and $\gamma_s$ are parameters for style type s. (See Section 14.2.4.2 for more details on instance 
normalization.) Surprisingly, this simple trick is enough to capture many kinds of styles. 
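
As an illustration, conditional instance normalization amounts to a per-channel standardization followed by a style-specific scale and shift. The NumPy sketch below makes several assumptions not stated in the text: features have shape (height, width, channels), statistics are computed per channel over the spatial dimensions, the style parameters are stored as tables indexed by the style id, and a small `eps` is added for numerical stability.

```python
import numpy as np

def conditional_instance_norm(feats, gamma, beta, s, eps=1e-5):
    """Standardize each channel, then apply the scale and shift of style s.

    feats: (H, W, C) features of the content image at one layer.
    gamma, beta: (num_styles, C) tables of style-specific parameters.
    """
    mu = feats.mean(axis=(0, 1), keepdims=True)      # per-channel mean
    sigma = feats.std(axis=(0, 1), keepdims=True)    # per-channel std
    return gamma[s] * (feats - mu) / (sigma + eps) + beta[s]

rng = np.random.default_rng(0)
feats = 3.0 + 2.0 * rng.standard_normal((8, 8, 3))
gamma = np.array([[1.0, 1.0, 1.0], [2.0, 0.5, 1.5]])   # two hypothetical styles
beta = np.array([[0.0, 0.0, 0.0], [0.1, -0.3, 0.7]])
out = conditional_instance_norm(feats, gamma, beta, s=1)
```

After the transformation, each channel of `out` has mean close to `beta[1]` and standard deviation close to `gamma[1]`, regardless of the statistics of the input features; switching styles only requires changing the index `s`.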

The drawback of the above technique is that it only works for a fixed number of discrete styles. 
[HB17] proposed to generalize this by replacing the constants $\beta_s$ and $\gamma_s$ by the output of another CNN, 
which takes an arbitrary style image $\mathbf{x}_s$ as input. That is, in Equation (14.38), we set $\beta_s = f_{\beta}(\phi(\mathbf{x}_s))$ 
and $\gamma_s = f_{\gamma}(\phi(\mathbf{x}_s))$, and we learn the parameters $\beta$ and $\gamma$ along with all the other parameters. The 
model becomes 


$$\text{AIN}(\phi(\mathbf{x}_c), \phi(\mathbf{x}_s)) = f_{\gamma}(\phi(\mathbf{x}_s)) \left( \frac{\phi(\mathbf{x}_c) - \mu(\phi(\mathbf{x}_c))}{\sigma(\phi(\mathbf{x}_c))} \right) + f_{\beta}(\phi(\mathbf{x}_s)) \tag{14.39}$$


They call their method adaptive instance normalization. 


15 Neural Networks for Sequences 


15.1 Introduction 


In this chapter, we discuss various kinds of neural networks for sequences. We will consider the 
case where the input is a sequence, the output is a sequence, or both are sequences. Such models 
have many applications, such as machine translation, speech recognition, text classification, image 
captioning, etc. Our presentation borrows from parts of [Zha+20], which should be consulted for 
more details. 


15.2 Recurrent neural networks (RNNs) 


A recurrent neural network or RNN is a neural network which maps from an input space of 
sequences to an output space of sequences in a stateful way. That is, the prediction of output $\mathbf{y}_t$ 
depends not only on the input $\mathbf{x}_t$, but also on the hidden state of the system, $\mathbf{h}_t$, which gets updated 
over time, as the sequence is processed. Such models can be used for sequence generation, sequence 
classification, and sequence translation, as we explain below.¹ 


15.2.1 Vec2Seq (sequence generation) 


In this section, we discuss how to learn functions of the form $f_{\boldsymbol{\theta}} : \mathbb{R}^{D} \rightarrow \mathbb{R}^{N_{\infty} C}$, where D is the 
size of the input vector, and the output is an arbitrary-length sequence of vectors, each of size C. 
(Note that words are discrete tokens, but can be converted to real-valued vectors as we discuss in 
Section 1.5.4.) We call these vec2seq models, since they map a vector to a sequence. 

The output sequence $\mathbf{y}_{1:T}$ is generated one token at a time. At each step we sample $\tilde{\mathbf{y}}_t$ from the 
hidden state $\mathbf{h}_t$ of the model, and then “feed it back in” to the model to get the new state $\mathbf{h}_{t+1}$ 
(which also depends on the input $\mathbf{x}$). See Figure 15.1 for an illustration. In this way the model 
defines a conditional generative model of the form $p(\mathbf{y}_{1:T}|\mathbf{x})$, which captures dependencies between 
the output tokens. We explain this in more detail below. 


1. For a more detailed introduction, see http://karpathy.github.io/2015/05/21/rnn-effectiveness/. 


Figure 15.1: Recurrent neural network (RNN) for generating a variable length output sequence $\mathbf{y}_{1:T}$ given an 
optional fixed length input vector $\mathbf{x}$. 


15.2.1.1 Models 


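For notational simplicity, let T be the length of the output (with the understanding that this is 
chosen dynamically). The RNN then corresponds to the following conditional generative model: 


$$p(\mathbf{y}_{1:T}|\mathbf{x}) = \sum_{\mathbf{h}_{1:T}} p(\mathbf{y}_{1:T}, \mathbf{h}_{1:T}|\mathbf{x}) = \sum_{\mathbf{h}_{1:T}} \prod_{t=1}^{T} p(\mathbf{y}_t|\mathbf{h}_t) \, p(\mathbf{h}_t|\mathbf{h}_{t-1}, \mathbf{y}_{t-1}, \mathbf{x}) \tag{15.1}$$


where $\mathbf{h}_t$ is the hidden state, and where we define $p(\mathbf{h}_1|\mathbf{h}_0, \mathbf{y}_0, \mathbf{x}) = p(\mathbf{h}_1|\mathbf{x})$ as the initial hidden 
state distribution (often deterministic). 
The output distribution is usually given by 


$$p(\mathbf{y}_t|\mathbf{h}_t) = \text{Cat}(\mathbf{y}_t|\text{softmax}(\mathbf{W}_{hy} \mathbf{h}_t + \mathbf{b}_y)) \tag{15.2}$$


where $\mathbf{W}_{hy}$ are the hidden-to-output weights, and $\mathbf{b}_y$ is the bias term. However, for real-valued 
outputs, we can use 


$$p(\mathbf{y}_t|\mathbf{h}_t) = \mathcal{N}(\mathbf{y}_t|\mathbf{W}_{hy} \mathbf{h}_t + \mathbf{b}_y, \sigma^2 \mathbf{I}) \tag{15.3}$$

We assume the hidden state is computed deterministically as follows: 

$$p(\mathbf{h}_t|\mathbf{h}_{t-1}, \mathbf{y}_{t-1}, \mathbf{x}) = \mathbb{I}(\mathbf{h}_t = f(\mathbf{h}_{t-1}, \mathbf{y}_{t-1}, \mathbf{x})) \tag{15.4}$$

for some deterministic function f. The update function f is usually given by 

$$\mathbf{h}_t = \varphi(\mathbf{W}_{xh} [\mathbf{x}; \mathbf{y}_{t-1}] + \mathbf{W}_{hh} \mathbf{h}_{t-1} + \mathbf{b}_h) \tag{15.5}$$


where $\mathbf{W}_{hh}$ are the hidden-to-hidden weights, $\mathbf{W}_{xh}$ are the input-to-hidden weights, and $\mathbf{b}_h$ are the 
bias terms. See Figure 15.1 for an illustration, and rnn_jax.ipynb for some code. 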

Note that $\mathbf{y}_t$ depends on $\mathbf{h}_t$, which depends on $\mathbf{y}_{t-1}$, which depends on $\mathbf{h}_{t-1}$, and so on. Thus $\mathbf{y}_t$ 
implicitly depends on all past observations (as well as the optional fixed input $\mathbf{x}$). Thus an RNN 
overcomes the limitations of standard Markov models, in that it can have unbounded memory. 
This makes RNNs theoretically as powerful as a Turing machine [SS95; PMB19]. In practice, 
however, the memory length is determined by the size of the latent state and the strength of the 
parameters; see Section 15.2.7 for further discussion of this point. 


the githa some thong the time traveller held in his hand was a glitteringmetallic framework scarcely larger than a 
small clock and verydelicately made there was ivory in it and the latter than s bettyre tat howhong s ie time thave ler 
simk you a dimensions le ghat dionthat shall travel indifferently in any direction of space and timeas the driver 
determinesfilby contented himself with laughterbut i have experimental verification said the time travellerit would be 
remarkably convenient for the histo 


Figure 15.2: Example output of length 500 generated from a character level RNN when given the prefix “the”. 
We use greedy decoding, in which the most likely character at each step is computed, and then fed back into 
the model. The model is trained on the book The Time Machine by H. G. Wells. Generated by rnn_jax.ipynb. 

When we generate from an RNN, we sample $\tilde{\mathbf{y}}_t \sim p(\mathbf{y}_t|\mathbf{h}_t)$, and then “feed in” the sampled 
value into the hidden state, to deterministically compute $\mathbf{h}_{t+1} = f(\mathbf{h}_t, \tilde{\mathbf{y}}_t, \mathbf{x})$, from which we sample 
$\tilde{\mathbf{y}}_{t+1} \sim p(\mathbf{y}_{t+1}|\mathbf{h}_{t+1})$, etc. Thus the only stochasticity in the system comes from the noise in the 
observation (output) model, which is fed back to the system in each step. (However, there is a variant, 
known as a variational RNN [Chu+15], that adds stochasticity to the dynamics of $\mathbf{h}_t$ independent 
of the observation noise.) 
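
The generation loop can be sketched as follows. This is a minimal NumPy illustration with random (untrained) weights, not the code in rnn_jax.ipynb; the function name `rnn_generate`, the tanh nonlinearity, the zero initial state, and the all-zeros vector used as the first "previous token" are assumptions of the sketch.

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def rnn_generate(x, params, T, rng, greedy=False):
    """Generate T tokens from a vec2seq RNN conditioned on the vector x."""
    Wxh, Whh, bh, Why, by = params
    C = by.shape[0]
    h = np.zeros(bh.shape[0])      # initial hidden state (assumed zero)
    y_prev = np.zeros(C)           # all-zeros "start" token (assumption)
    tokens = []
    for _ in range(T):
        # Deterministic state update from the input, previous token and state.
        h = np.tanh(Wxh @ np.concatenate([x, y_prev]) + Whh @ h + bh)
        probs = softmax(Why @ h + by)          # categorical output distribution
        if greedy:
            y = int(np.argmax(probs))          # most likely token
        else:
            y = int(rng.choice(C, p=probs))    # the only source of randomness
        tokens.append(y)
        y_prev = np.eye(C)[y]                  # feed the one-hot token back in
    return tokens

rng = np.random.default_rng(0)
D, Hdim, C = 3, 8, 5
params = (rng.standard_normal((Hdim, D + C)), rng.standard_normal((Hdim, Hdim)),
          np.zeros(Hdim), rng.standard_normal((C, Hdim)), np.zeros(C))
x = rng.standard_normal(D)
sampled = rnn_generate(x, params, T=10, rng=rng)
greedy_a = rnn_generate(x, params, T=10, rng=rng, greedy=True)
greedy_b = rnn_generate(x, params, T=10, rng=rng, greedy=True)
```

Because the state update is deterministic, the two greedy runs produce identical sequences, whereas repeated sampled runs generally differ.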


15.2.1.2 Applications 


RNNs can be used to generate sequences unconditionally (by setting $\mathbf{x} = \emptyset$) or conditionally on $\mathbf{x}$. 
Unconditional sequence generation is often called language modeling; this refers to learning joint 
probability distributions over sequences of discrete tokens, i.e., models of the form $p(y_1, \ldots, y_T)$. (See 
also Section 3.6.1.2, where we discuss using Markov chains for language modeling.) 

Figure 15.2 shows a sequence generated from a simple RNN trained on the book The Time Machine 
by H. G. Wells. (This is a short science fiction book, with just 32,000 words and 170k characters.) We 
see that the generated sequence looks plausible, even though it is not very meaningful. By using more 
sophisticated RNN models (such as those that we discuss in Section 15.2.7.1 and Section 15.2.7.2), 
and by training on more data, we can create RNNs that give state-of-the-art performance on the 
language modeling task [CNB17]. (In the language modeling community, performance is usually 
measured by perplexity, which is just the exponential of the average per-token negative log likelihood; 
see Section 6.1.5 for more information.) 
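
As a small worked example of the perplexity definition just given, the following sketch computes it from the probabilities that a model assigns to each observed token. The function name and the example probabilities are made up for illustration.

```python
import numpy as np

def perplexity(token_probs):
    """Exponential of the average per-token negative log likelihood."""
    nll = -np.mean(np.log(token_probs))
    return float(np.exp(nll))

# A model that assigns probability 1/4 to every observed token.
uniform = perplexity([0.25] * 10)
# A model that is more confident about the observed tokens.
confident = perplexity([0.9, 0.8, 0.95, 0.7])
```

A model that assigns probability 1/4 to every observed token has perplexity 4, matching the intuition that perplexity measures the effective number of equally likely choices per step; lower is better.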

We can also make the generated sequence depend on some kind of input vector $\mathbf{x}$. For example, 
consider the task of image captioning: in this case, x is some embedding of the image computed 
by a CNN, as illustrated in Figure 15.3. See e.g., [Hos+19; LXW19] for a review of image captioning 
methods, and https://bit.ly/2WvsiGK for a tutorial with code. 

It is also possible to use RNNs to generate sequences of real-valued feature vectors, such as pen 
strokes for hand-written characters [Gra13] and hand-drawn shapes [HE18]. This can also be useful 
for forecasting real-valued time series. 


15.2.2 Seq2Vec (sequence classification) 


In this section, we assume we have a single fixed-length output vector $\mathbf{y}$ we want to predict, given a 
variable length sequence as input. Thus we want to learn a function of the form $f_{\boldsymbol{\theta}} : \mathbb{R}^{TD} \rightarrow \mathbb{R}^{C}$. We 
call this a seq2vec model. We will focus on the case where the output is a class label, $y \in \{1, \ldots, C\}$, 
for notational simplicity. 


Figure 15.3: Illustration of a CNN-RNN model for image captioning. The pink boxes labeled “LSTM” refer to 
a specific kind of RNN that we discuss in Section 15.2.7.2. The pink boxes labeled $\mathbf{W}_{\text{emb}}$ refer to embedding 
matrices for the (sampled) one-hot tokens, so that the input to the model is a real-valued vector. From 
https://bit.ly/2FKnqHm. Used with kind permission of Yunjey Choi. 


Figure 15.4: (a) RNN for sequence classification. (b) Bi-directional RNN for sequence classification. 


The simplest approach is to use the final state of the RNN as input to the classifier: 


$$p(y|\mathbf{x}_{1:T}) = \text{Cat}(y|\text{softmax}(\mathbf{W} \mathbf{h}_T)) \tag{15.6}$$


See Figure 15.4a for an illustration. 

We can often get better results if we let the hidden states of the RNN depend on the past and 
future context. To do this, we create two RNNs, one which recursively computes hidden states in the 
forwards direction, and one which recursively computes hidden states in the backwards direction. 
This is called a bidirectional RNN [SP97]. 


More precisely, the model is defined as follows: 


$$\mathbf{h}_t^{\rightarrow} = \varphi(\mathbf{W}_{xh}^{\rightarrow} \mathbf{x}_t + \mathbf{W}_{hh}^{\rightarrow} \mathbf{h}_{t-1}^{\rightarrow} + \mathbf{b}_h^{\rightarrow}) \tag{15.7}$$
$$\mathbf{h}_t^{\leftarrow} = \varphi(\mathbf{W}_{xh}^{\leftarrow} \mathbf{x}_t + \mathbf{W}_{hh}^{\leftarrow} \mathbf{h}_{t+1}^{\leftarrow} + \mathbf{b}_h^{\leftarrow}) \tag{15.8}$$


We can then define $\mathbf{h}_t = [\mathbf{h}_t^{\rightarrow}, \mathbf{h}_t^{\leftarrow}]$ to be the representation of the state at time t, taking into account 
past and future information. Finally we average pool over these hidden states to get the final classifier: 


$$p(y|\mathbf{x}_{1:T}) = \text{Cat}(y|\text{softmax}(\mathbf{W} \bar{\mathbf{h}})) \tag{15.9}$$
$$\bar{\mathbf{h}} = \frac{1}{T} \sum_{t=1}^{T} \mathbf{h}_t \tag{15.10}$$


See Figure 15.4b for an illustration, and rnn_sentiment_jax.ipynb for some code. (This is similar to 
the 1d CNN text classifier in Section 15.3.1.) 


Figure 15.5: (a) RNN for transforming a sequence to another, aligned sequence. (b) Bi-directional RNN for 
the same task. 
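
A bidirectional RNN classifier with average pooling can be sketched in NumPy as follows. This is an illustrative sketch with random, untrained weights, not the code in rnn_sentiment_jax.ipynb; the function names, the tanh nonlinearity and the zero initial states are assumptions, and the softmax is applied after the linear layer, as in Equation (15.6).

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def run_rnn(xs, Wxh, Whh, bh):
    """Run a simple RNN over the rows of xs and return all hidden states."""
    h = np.zeros(bh.shape[0])
    hs = []
    for x in xs:
        h = np.tanh(Wxh @ x + Whh @ h + bh)
        hs.append(h)
    return np.stack(hs)                          # (T, H)

def birnn_classify(xs, fwd, bwd, W):
    h_fwd = run_rnn(xs, *fwd)                    # processes x_1, ..., x_T
    # Process x_T, ..., x_1, then flip so row t again corresponds to time t.
    h_bwd = run_rnn(xs[::-1], *bwd)[::-1]
    h = np.concatenate([h_fwd, h_bwd], axis=1)   # (T, 2H): past and future context
    h_bar = h.mean(axis=0)                       # average pooling over time
    return softmax(W @ h_bar)

rng = np.random.default_rng(0)
T, D, Hdim, C = 6, 3, 4, 2
make = lambda: (rng.standard_normal((Hdim, D)), rng.standard_normal((Hdim, Hdim)),
                np.zeros(Hdim))
fwd, bwd = make(), make()
W = rng.standard_normal((C, 2 * Hdim))
xs = rng.standard_normal((T, D))
probs = birnn_classify(xs, fwd, bwd, W)
```

The result `probs` is a distribution over the C class labels; because of the backward chain, even the representation of the first time step depends on the last input.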


15.2.3 Seq2Seq (sequence translation) 


In this section, we consider learning functions of the form $f_{\boldsymbol{\theta}} : \mathbb{R}^{TD} \rightarrow \mathbb{R}^{T'C}$. We consider two cases: 
one in which $T' = T$, so the input and output sequences have the same length (and hence are aligned), 
and one in which $T' \neq T$, so the input and output sequences have different lengths. This is called a 
seq2seq problem. 


15.2.3.1 Aligned case 


In this section, we consider the case where the input and output sequences are aligned. We can also 
think of it as dense sequence labeling, since we predict one label per location. It is straightforward 
to modify an RNN to solve this task, as shown in Figure 15.5a. This corresponds to 


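$$p(\mathbf{y}_{1:T}|\mathbf{x}_{1:T}) = \sum_{\mathbf{h}_{1:T}} \prod_{t=1}^{T} p(\mathbf{y}_t|\mathbf{h}_t) \, \mathbb{I}(\mathbf{h}_t = f(\mathbf{h}_{t-1}, \mathbf{x}_t)) \tag{15.11}$$


where we define $\mathbf{h}_1 = f(\mathbf{h}_0, \mathbf{x}_1) = f_0(\mathbf{x}_1)$ to be the initial state. 


Figure 15.7: Encoder-decoder RNN architecture for mapping sequence $\mathbf{x}_{1:T}$ to sequence $\mathbf{y}_{1:T'}$. 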

Note that $\mathbf{y}_t$ depends on $\mathbf{h}_t$, which only depends on the past inputs, $\mathbf{x}_{1:t}$. We can get better results 
if we let the decoder look into the “future” of $\mathbf{x}$ as well as the past, by using a bidirectional RNN, as 
shown in Figure 15.5b. 

We can create more expressive models by stacking multiple hidden chains on top of each other, as 
shown in Figure 15.6. The hidden units for layer l at time t are computed using 


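$$\mathbf{h}_t^{l} = \varphi_l(\mathbf{W}_{xh}^{l} \mathbf{h}_t^{l-1} + \mathbf{W}_{hh}^{l} \mathbf{h}_{t-1}^{l} + \mathbf{b}_h^{l}) \tag{15.12}$$

The output is given by 

$$\mathbf{o}_t = \mathbf{W}_{ho} \mathbf{h}_t^{L} + \mathbf{b}_o \tag{15.13}$$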

15.2.3.2 Unaligned case 


In this section, we discuss how to learn a mapping from one sequence of length T to another of length 
T'. We first encode the input sequence to get the context vector $\mathbf{c} = f_e(\mathbf{x}_{1:T})$, using the last state of 
an RNN (or average pooling over a biRNN). We then generate the output sequence using an RNN 
decoder $\mathbf{y}_{1:T'} = f_d(\mathbf{c})$. This is called an encoder-decoder architecture [SVL14; Cho+14a]. See 
Figure 15.7 for an illustration. 


Figure 15.8: (a) Illustration of a seq2seq model for translating English to French. The - character represents 
the end of a sentence. From Figure 2.4 of [Luo16]. Used with kind permission of Minh-Thang Luong. (b) 
Illustration of greedy decoding. The most likely French word at each step is highlighted in green, and then 
fed in as input to the next step of the decoder. From Figure 2.5 of [Luo16]. Used with kind permission of 
Minh-Thang Luong. 

An important application of this is machine translation. When this is tackled using RNNs, 
it is called neural machine translation (as opposed to the older approach called statistical 
machine translation, that did not use neural networks). See Figure 15.8a for the basic idea, and 
nmt_jax.ipynb for some code which has more details. For a review of the NMT literature, see [Luo16; 
Neu17]. 


15.2.4 Teacher forcing 


When training a language model, the likelihood of a sequence of words $w_1, w_2, \ldots, w_T$ is given by 

$$p(w_{1:T}) = \prod_{t=1}^{T} p(w_t | w_{1:t-1}) \tag{15.14}$$


In an RNN, we therefore set the input to $x_t = w_{t-1}$ and the output to $y_t = w_t$. Note that we 
condition on the ground truth labels from the past, $w_{1:t-1}$, not labels generated from the model. 
This is called teacher forcing, since the teacher’s values are “force fed” into the model as input at 
each step (i.e., $x_t$ is set to $w_{t-1}$). 

Unfortunately, teacher forcing can sometimes result in models that perform poorly at test time. 
The reason is that the model has only ever been trained on inputs that are “correct”, so it may not 
know what to do if, at test time, it encounters an input sequence $w_{1:t-1}$ generated from the previous 
step that deviates from what it saw in training. 

Figure 15.9: An RNN unrolled (vertically) for 3 time steps, with the target output sequence and loss node 
shown explicitly. From Figure 8.7.2 of [Zha+20]. Used with kind permission of Aston Zhang. 


A common solution to this is known as scheduled sampling [Ben+15a]. This starts off using 
teacher forcing, but at random time steps, feeds in samples from the model instead; the fraction of 
time this happens is gradually increased. 
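
The input-selection rule behind scheduled sampling can be sketched as follows. This is a minimal illustration, not the procedure of [Ben+15a]: the function names, the per-step coin flip, and the linear schedule are all assumptions made for the example.

```python
import random


def choose_decoder_inputs(targets, model_predictions, sample_prob, rng):
    """Pick the token fed into the RNN at each step.

    targets[i] is the ground-truth previous token (teacher forcing) and
    model_predictions[i] is the model's own previous output. With
    probability sample_prob we feed the model's output instead.
    """
    inputs = []
    for truth, pred in zip(targets, model_predictions):
        use_model = rng.random() < sample_prob
        inputs.append(pred if use_model else truth)
    return inputs


def sample_prob_schedule(step, total_steps):
    """Hypothetical linear schedule: 0 (pure teacher forcing) rising to 1."""
    return min(1.0, step / total_steps)


rng = random.Random(0)
truth = ["the", "cat", "sat"]
preds = ["the", "dog", "sat"]
print(choose_decoder_inputs(truth, preds, sample_prob_schedule(5, 10), rng))
```

With `sample_prob = 0` this reduces to teacher forcing; with `sample_prob = 1` the model is always fed its own outputs, as at test time.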

An alternative solution is to use other kinds of models where MLE training works better, such as 
1d CNNs (Section 15.3) and transformers (Section 15.5). 


15.2.5 Backpropagation through time 


We can compute the maximum likelihood estimate of the parameters for an RNN by solving 
$\theta^* = \operatorname{argmax}_\theta p(y_{1:T} | x_{1:T}, \theta)$, where we have assumed a single training sequence for notational 
simplicity. To compute the MLE, we have to compute gradients of the loss wrt the parameters. 
To do this, we can unroll the computation graph, as shown in Figure 15.9, and then apply the 
backpropagation algorithm. This is called backpropagation through time (BPTT) [Wer90]. 
More precisely, consider the following model: 


$$h_t = W_{hx} x_t + W_{hh} h_{t-1} \tag{15.15}$$
$$o_t = W_{ho} h_t \tag{15.16}$$


where $o_t$ are the output logits, and where we drop the bias terms for notational simplicity. We 
assume $y_t$ are the true target labels for each time step, so we define the loss to be 

$$L = \frac{1}{T} \sum_{t=1}^{T} \ell(y_t, o_t) \tag{15.17}$$


We need to compute the derivatives $\partial L / \partial W_{hx}$, $\partial L / \partial W_{hh}$ and $\partial L / \partial W_{ho}$. The latter term is easy, since it is 
local to each time step. However, the first two terms depend on the hidden state, and thus require 
working backwards in time. 

We simplify the notation by defining 

$$h_t = f(x_t, h_{t-1}, w_h) \tag{15.18}$$
$$o_t = g(h_t, w_o) \tag{15.19}$$


By the chain rule, we have 


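$$\frac{\partial L}{\partial w_h} = \frac{1}{T} \sum_{t=1}^{T} \frac{\partial \ell(y_t, o_t)}{\partial w_h} = \frac{1}{T} \sum_{t=1}^{T} \frac{\partial \ell(y_t, o_t)}{\partial o_t} \frac{\partial g(h_t, w_o)}{\partial h_t} \frac{\partial h_t}{\partial w_h} \tag{15.20}$$
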
We can expand the last term as follows: 
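
$$\frac{\partial h_t}{\partial w_h} = \frac{\partial f(x_t, h_{t-1}, w_h)}{\partial w_h} + \frac{\partial f(x_t, h_{t-1}, w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h} \tag{15.21}$$
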
If we expand this recursively, we find the following result (see the derivation in [Zha+20, Sec 8.7]): 
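
$$\frac{\partial h_t}{\partial w_h} = \frac{\partial f(x_t, h_{t-1}, w_h)}{\partial w_h} + \sum_{i=1}^{t-1} \left( \prod_{j=i+1}^{t} \frac{\partial f(x_j, h_{j-1}, w_h)}{\partial h_{j-1}} \right) \frac{\partial f(x_i, h_{i-1}, w_h)}{\partial w_h} \tag{15.22}$$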


Unfortunately, this takes O(T) time to compute per time step, for a total of $O(T^2)$ overall. It is 
therefore standard to truncate the sum to the most recent K terms. It is possible to adaptively pick 
a suitable truncation parameter K [AFF19]; however, it is usually set equal to the length of the 
subsequence in the current minibatch. 
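
To make the recursion in Equation (15.21) and its truncation concrete, here is a small sketch for a scalar linear RNN, $h_t = w_x x_t + w_h h_{t-1}$. It is an illustrative assumption that only the recurrent weight $w_h$ is differentiated, so that $\partial f / \partial w_h = h_{t-1}$ and $\partial f / \partial h_{t-1} = w_h$. The function names and the `truncate` argument are hypothetical.

```python
def run_rnn(xs, w_x, w_h, h0=0.0):
    """Scalar linear RNN; returns the states [h_0, h_1, ..., h_T]."""
    hs = [h0]
    for x in xs:
        hs.append(w_x * x + w_h * hs[-1])
    return hs


def dh_dwh(xs, w_x, w_h, truncate=None):
    """Derivative of the final state h_T wrt w_h via the recursion (15.21).

    If truncate=K, only the K most recent terms of the sum are kept,
    which is the truncated BPTT approximation.
    """
    hs = run_rnn(xs, w_x, w_h)
    T = len(xs)
    start = 1 if truncate is None else max(1, T - truncate + 1)
    grad = 0.0
    for t in range(start, T + 1):
        # direct term h_{t-1}, plus the gradient carried over from step t-1
        grad = hs[t - 1] + w_h * grad
    return grad


xs = [1.0, 2.0, 3.0]
print(dh_dwh(xs, 0.5, 0.9))               # full BPTT
print(dh_dwh(xs, 0.5, 0.9, truncate=1))   # keep only the most recent term
```

The full gradient can be checked against a finite-difference estimate; the truncated version differs by exactly the dropped older terms.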

When using truncated BPTT, we can train the model with batches of short sequences, usually 
created by extracting non-overlapping subsequences (windows) from the original sequence. If the 
previous subsequence ends at time t − 1, and the current subsequence starts at time t, we can “carry 
over” the hidden state of the RNN across batch updates during training. However, if the subsequences 
are not ordered, we need to reset the hidden state. See rnn_jax.ipynb for some sample code that 
illustrates these details. 


15.2.6 Vanishing and exploding gradients 


Unfortunately, the activations in an RNN can decay or explode as we go forwards in time, since 
we multiply by the weight matrix $W_{hh}$ at each time step. Similarly, the gradients in an RNN can 
decay or explode as we go backwards in time, since we multiply the Jacobians at each time step (see 
Section 13.4.2 for details). A simple heuristic is to use gradient clipping (Equation (13.70)). More 
sophisticated methods attempt to control the spectral radius λ of the forward mapping, $W_{hh}$, as 
well as the backwards mapping, given by the Jacobian $J_{hh}$. 
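
The sketch below illustrates both points in the scalar case: repeated multiplication by a weight with magnitude below or above 1 decays or explodes, and clipping rescales a gradient whose norm exceeds a threshold. The clipping rule shown is one common form (rescale to a maximum norm); that it matches Equation (13.70) exactly is an assumption, and the function name is hypothetical.

```python
import math


def clip_by_norm(grad, max_norm):
    """Rescale grad so that its Euclidean norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)
    scale = max_norm / norm
    return [g * scale for g in grad]


# Scalar analogue of multiplying by W_hh for 50 time steps.
for w in (0.9, 1.0, 1.1):
    print(w, w ** 50)

print(clip_by_norm([3.0, 4.0], max_norm=1.0))
```

Clipping changes the length of the gradient but not its direction, so it limits the step size without controlling the spectral radius itself.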

The simplest way to control the spectral radius is to randomly initialize $W_{hh}$ in such a way as 
to ensure λ ≈ 1, and then keep it fixed (i.e., we do not learn $W_{hh}$). In this case, only the output 
matrix $W_{ho}$ needs to be learned, resulting in a convex optimization problem. This is called an echo 
state network [JH04]. A closely related approach, known as a liquid state machine [MNM02], 
uses binary-valued (spiking) neurons instead of real-valued neurons. A generic term for both ESNs 
and LSMs is reservoir computing [LJ09]. Another approach to this problem is to use constrained 
optimization to ensure the $W_{hh}$ matrix remains orthogonal [Vor+17]. 

Figure 15.10: Illustration of a GRU. Adapted from Figure 9.1.3 of [Zha+20]. 

An alternative to explicitly controlling the spectral radius is to modify the RNN architecture itself, 
to use additive rather than multiplicative updates to the hidden states, as we discuss in Section 15.2.7. 
This significantly improves training stability. 


15.2.7 Gating and long term memory 


RNNs with enough hidden units can in principle remember inputs from long in the past. However, in 
practice “vanilla” RNNs fail to do this because of the vanishing gradient problem (Section 13.4.2). In 
this section we give a solution to this in which we update the hidden state in an additive way, similar 
to a residual net (Section 14.3.4). 


15.2.7.1 Gated recurrent units (GRU) 


In this section, we discuss models which use gated recurrent units (GRU), as proposed in 
[Cho+14a]. The key idea is to learn when to update the hidden state, by using a gating unit. This 
can be used to selectively “remember” important pieces of information when they are first seen. The 
model can also learn when to reset the hidden state, and thus forget things that are no longer useful. 

To explain the model in more detail, we present it in two steps, following the presentation of 
[Zha+20, Sec 8.8]. We assume $X_t$ is an N × D matrix, where N is the batch size, and D is the 
vocabulary size. Similarly, $H_t$ is an N × H matrix, where H is the number of hidden units at time t. 
The reset gate $R_t \in \mathbb{R}^{N \times H}$ and update gate $Z_t \in \mathbb{R}^{N \times H}$ are computed using 

$$R_t = \sigma(X_t W_{xr} + H_{t-1} W_{hr} + b_r) \tag{15.23}$$
$$Z_t = \sigma(X_t W_{xz} + H_{t-1} W_{hz} + b_z) \tag{15.24}$$


Note that each element of $R_t$ and $Z_t$ is in [0,1], because of the sigmoid function. 
Given this, we define a “candidate” next state vector using 

$$\tilde{H}_t = \tanh(X_t W_{xh} + (R_t \odot H_{t-1}) W_{hh} + b_h) \tag{15.25}$$


This combines the old memories that are not reset (computed using $R_t \odot H_{t-1}$) with the new inputs 
$X_t$. We pass the resulting linear combination through a tanh function to ensure the hidden units 
remain in the interval (−1, 1). If the entries of the reset gate $R_t$ are close to 1, we recover the 
standard RNN update rule. If the entries are close to 0, the model acts more like an MLP applied to 
$X_t$. Thus the reset gate can capture new, short-term information. 

Once we have computed the candidate new state, the model computes the actual new state by 
using the dimensions from the candidate state $\tilde{H}_t$ chosen by the update gate, $1 - Z_t$, and keeping 
the remaining dimensions at their old values of $H_{t-1}$: 

$$H_t = Z_t \odot H_{t-1} + (1 - Z_t) \odot \tilde{H}_t \tag{15.26}$$

When $Z_{td} = 1$, we pass $H_{t-1,d}$ through unchanged, and ignore $X_t$. Thus the update gate can capture 
long-term dependencies. 

See Figure 15.10 for an illustration of the overall architecture, and gru_jax.ipynb for some sample 
code. 
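
As a complement to the notebook, the GRU update in Equations (15.23) to (15.26) can be sketched for the scalar case (N = D = H = 1), which keeps the example free of matrix code. The parameter dictionary and its key names are hypothetical; they simply mirror the subscripts in the equations.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def gru_step(x, h_prev, p):
    """One GRU update for scalar input x and scalar state h_prev."""
    r = sigmoid(x * p["w_xr"] + h_prev * p["w_hr"] + p["b_r"])  # reset gate
    z = sigmoid(x * p["w_xz"] + h_prev * p["w_hz"] + p["b_z"])  # update gate
    # candidate state: new input combined with the non-reset part of the old state
    h_cand = math.tanh(x * p["w_xh"] + (r * h_prev) * p["w_hh"] + p["b_h"])
    # keep the old state where z is near 1, take the candidate where z is near 0
    return z * h_prev + (1.0 - z) * h_cand


params = {"w_xr": 0.5, "w_hr": -0.3, "b_r": 0.1,
          "w_xz": 0.2, "w_hz": 0.4, "b_z": 0.0,
          "w_xh": 1.0, "w_hh": 0.7, "b_h": 0.0}
h = 0.0
for x in [1.0, -0.5, 0.25]:
    h = gru_step(x, h, params)
print(h)
```

Setting the update-gate bias `b_z` to a large positive value drives z towards 1, so the state is copied forward almost unchanged, which is the long-term memory behavior described above.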


15.2.7.2 Long short term memory (LSTM) 


In this section, we discuss the long short term memory (LSTM) model of [HS97b], which is a 
more sophisticated version of the GRU (and pre-dates it by almost 20 years). For a more detailed 
introduction, see https://colah.github.io/posts/2015-08-Understanding-LSTMs. 

The basic idea is to augment the hidden state $h_t$ with a memory cell $c_t$. We need three gates to 
control this cell: the output gate $O_t$ determines what gets read out; the input gate $I_t$ determines 
what gets read in; and the forget gate $F_t$ determines when we should reset the cell. These gates 
are computed as follows: 

$$O_t = \sigma(X_t W_{xo} + H_{t-1} W_{ho} + b_o) \tag{15.27}$$
$$I_t = \sigma(X_t W_{xi} + H_{t-1} W_{hi} + b_i) \tag{15.28}$$
$$F_t = \sigma(X_t W_{xf} + H_{t-1} W_{hf} + b_f) \tag{15.29}$$


We then compute a candidate cell state: 

$$\tilde{C}_t = \tanh(X_t W_{xc} + H_{t-1} W_{hc} + b_c) \tag{15.30}$$


The actual update to the cell is either the candidate cell (if the input gate is on) or the old cell (if 
the not-forget gate is on): 

$$C_t = F_t \odot C_{t-1} + I_t \odot \tilde{C}_t \tag{15.31}$$

If $F_t = 1$ and $I_t = 0$, this can remember long term memories.² 

Figure 15.11: Illustration of an LSTM. Adapted from Figure 9.2.4 of [Zha+20]. 


Finally, we compute the hidden state to be a transformed version of the cell, provided the output 
gate is on: 

$$H_t = O_t \odot \tanh(C_t) \tag{15.32}$$


Note that $H_t$ is used as the output of the unit as well as the hidden state for the next time step. 
This lets the model remember what it has just output (short-term memory), whereas the cell $C_t$ acts 
as a long-term memory. See Figure 15.11 for an illustration of the overall model, and lstm_jax.ipynb 
for some sample code. 
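
The LSTM update in Equations (15.27) to (15.32) can likewise be sketched in the scalar case. As with any toy version, this is an illustration under assumptions: the parameter names are hypothetical and chosen to mirror the subscripts in the equations.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def lstm_step(x, h_prev, c_prev, p):
    """One LSTM update for scalar input, hidden state and cell."""
    o = sigmoid(x * p["w_xo"] + h_prev * p["w_ho"] + p["b_o"])  # output gate
    i = sigmoid(x * p["w_xi"] + h_prev * p["w_hi"] + p["b_i"])  # input gate
    f = sigmoid(x * p["w_xf"] + h_prev * p["w_hf"] + p["b_f"])  # forget gate
    c_cand = math.tanh(x * p["w_xc"] + h_prev * p["w_hc"] + p["b_c"])
    c = f * c_prev + i * c_cand   # additive cell update
    h = o * math.tanh(c)          # hidden state read out from the cell
    return h, c


params = {k: 0.1 for k in ["w_xo", "w_ho", "w_xi", "w_hi",
                           "w_xf", "w_hf", "w_xc", "w_hc"]}
params.update({"b_o": 0.0, "b_i": 0.0, "b_f": 1.0, "b_c": 0.0})
h, c = 0.0, 0.0
for x in [1.0, -0.5, 0.25]:
    h, c = lstm_step(x, h, c, params)
print(h, c)
```

With a large forget-gate bias and a strongly negative input-gate bias, the cell value is carried forward essentially unchanged, which is the F = 1, I = 0 case discussed above.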

Sometimes we add peephole connections, where we pass the cell state as an additional input 
to the gates. Many other variants have been proposed. In fact, [JZS15] used genetic algorithms to 
test over 10,000 different architectures. Some of these worked better than LSTMs or GRUs, but in 
general, LSTMs seemed to do consistently well across most tasks. Similar conclusions were reached 
in [Gre+17]. More recently, [ZL17] used an RNN controller to generate strings which specify RNN 
architectures, and then trained the controller using reinforcement learning. This resulted in a novel 
cell structure that outperformed LSTM. However, it is rather complex and has not been adopted by 
the community. 


2. One important detail pointed out in [JZS15] is that we need to initialize the bias term for the forget gate $b_f$ to 
be large, so the sigmoid is close to 1. This ensures that information can easily pass through the C chain over time. 
Without this trick, performance is often much worse. 

Figure 15.12: Conditional probabilities of generating each token at each step for two different sequences. From 
Figures 9.8.1-9.8.2 of [Zha+20]. Used with kind permission of Aston Zhang. 


15.2.8 Beam search 


The simplest way to generate from an RNN is to use greedy decoding, in which we compute 
$\hat{y}_t = \operatorname{argmax}_y p(y_t = y | \hat{y}_{1:t-1}, x)$ at each step. We can repeat this process until we generate the 
end-of-sentence token. See Figure 15.8b for an illustration of this method applied to NMT. 

Unfortunately greedy decoding will not generate the MAP sequence, which is defined by 
$y^*_{1:T} = \operatorname{argmax}_{y_{1:T}} p(y_{1:T} | x)$. The reason is that the locally optimal symbol at step t might not be on the 
globally optimal path. 

As an example, consider Figure 15.12a. We greedily pick the MAP symbol at step 1, which is A. 
Conditional on this, suppose we have $p(y_2 | y_1 = A) = [0.1, 0.4, 0.3, 0.2]$, as shown. We greedily pick 
the MAP symbol from this, which is B. Conditional on this, suppose we have $p(y_3 | y_1 = A, y_2 = B) = 
[0.2, 0.2, 0.4, 0.2]$, as shown. We greedily pick the MAP symbol from this, which is C. Conditional 
on this, suppose we have $p(y_4 | y_1 = A, y_2 = B, y_3 = C) = [0.0, 0.2, 0.2, 0.6]$, as shown. We greedily 
pick the MAP symbol from this, which is eos (end of sentence), so we stop generating. The overall 
probability of the generated sequence is 0.5 × 0.4 × 0.4 × 0.6 = 0.048. 

Now consider Figure 15.12b. At step 2, suppose we pick the second most probable token, namely 
C. Conditional on this, suppose we have $p(y_3 | y_1 = A, y_2 = C) = [0.1, 0.6, 0.2, 0.1]$, as shown. 
We greedily pick the MAP symbol from this, which is B. Conditional on this, suppose we have 
$p(y_4 | y_1 = A, y_2 = C, y_3 = B) = [0.1, 0.2, 0.1, 0.6]$, as shown. We greedily pick the MAP symbol from 
this, which is eos (end of sentence), so we stop generating. The overall probability of the generated 
sequence is 0.5 × 0.3 × 0.6 × 0.6 = 0.054. So by being less greedy, we found a sequence with overall 
higher likelihood. 

For hidden Markov models, we can use an algorithm called Viterbi decoding (which is an example 
of dynamic programming) to compute the globally optimal sequence in $O(TV^2)$ time, where V is 
the number of words in the vocabulary. (See [Mur23] for details.) But for RNNs, computing the 
global optimum takes $O(V^T)$ time, since the hidden state is not a sufficient statistic for the data. 

Beam search is a much faster heuristic method. In this approach, we compute the top K 
candidate outputs at each step; we then expand each one in all V possible ways, to generate VK 
candidates, from which we pick the top K again. This process is illustrated in Figure 15.13. 
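
The procedure can be sketched on the toy example of Figure 15.12. In the table below, the conditional distributions along the two paths and p(y_1 = A) = 0.5 come from the worked example above; the remaining first-step probabilities and the uniform fallback for all other prefixes are assumptions made only so that the model is fully specified. For simplicity the sketch decodes a fixed number of steps and does not treat eos specially.

```python
TOKENS = ["A", "B", "C", "eos"]

# Conditional distributions p(y_t | prefix), ordered as in TOKENS.
TABLE = {
    (): [0.5, 0.2, 0.2, 0.1],  # only p(A) = 0.5 is from the text
    ("A",): [0.1, 0.4, 0.3, 0.2],
    ("A", "B"): [0.2, 0.2, 0.4, 0.2],
    ("A", "B", "C"): [0.0, 0.2, 0.2, 0.6],
    ("A", "C"): [0.1, 0.6, 0.2, 0.1],
    ("A", "C", "B"): [0.1, 0.2, 0.1, 0.6],
}


def next_probs(prefix):
    """Toy model: table lookup, with an assumed uniform fallback."""
    return TABLE.get(prefix, [0.25, 0.25, 0.25, 0.25])


def beam_search(next_probs, tokens, length, beam_size):
    """Keep the beam_size most probable prefixes at every step."""
    beams = [((), 1.0)]
    for _ in range(length):
        candidates = []
        for prefix, score in beams:
            # expand each of the K prefixes in all V possible ways
            for tok, p in zip(tokens, next_probs(prefix)):
                candidates.append((prefix + (tok,), score * p))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return beams


print(beam_search(next_probs, TOKENS, length=4, beam_size=1)[0])  # greedy
print(beam_search(next_probs, TOKENS, length=4, beam_size=2)[0])
```

With a beam of size 1 this reduces to greedy decoding and returns the sequence A, B, C, eos; with a beam of size 2 it recovers the higher-probability sequence A, C, B, eos. A practical implementation would work with log probabilities and would set aside hypotheses once they emit eos.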

It is also possible to extend the algorithm to sample the top K sequences without replacement 
(i.e., pick the top one, renormalize, pick the new top one, etc.), using a method called stochastic 
beam search. This perturbs the model’s partial probabilities at each step with Gumbel noise. See 
[KHW19] for details, and [SBS20] for a sequential alternative. These sampling methods can improve 
diversity of the outputs. (See also the deterministic diverse beam search method of [Vij+18].) 

Figure 15.13: Illustration of beam search using a beam of size K = 2. The vocabulary is $\mathcal{Y} = \{A, B, C, D, E\}$, 
with size V = 5. We assume the top 2 symbols at step 1 are A, C. At step 2, we evaluate $p(y_1 = A, y_2 = y)$ 
and $p(y_1 = C, y_2 = y)$ for each $y \in \mathcal{Y}$. This takes O(KV) time. We then pick the top 2 partial paths, which 
are $(y_1 = A, y_2 = B)$ and $(y_1 = C, y_2 = E)$, and continue in the obvious way. Adapted from Figure 9.8.3 of 
[Zha+20]. 


15.3 1d CNNs 


Convolutional neural networks (Chapter 14) compute a function of some local neighborhood for each 
input using tied weights, and return an output. They are usually used for 2d inputs, but can also be 
applied in the 1d case, as we discuss below. They are an interesting alternative to RNNs that are 
much easier to train, because they don’t have to maintain long term hidden state. 


15.3.1 1d CNNs for sequence classification 


In this section, we discuss the use of 1d CNNs for learning a mapping from variable-length sequences 
to a fixed length output, i.e., a function of the form $f_\theta : \mathbb{R}^{DT} \to \mathbb{R}^{C}$, where T is the length of the 
input, D is the number of features per input, and C is the size of the output vector (e.g., class logits). 

A basic 1d convolution operation applied to a 1d sequence is shown in Figure 14.4. Typically the 
input sequence will have D > 1 input channels (feature dimensions). In this case, we can convolve 
each channel separately and add up the result, using a different 1d filter (kernel) for each input 
channel to get $z_i = \sum_d x_{i-k:i+k,d}^{\top} w_d$, where k is the size of the 1d receptive field, and $w_d$ is the filter for 
input channel d. This produces a 1d vector $z \in \mathbb{R}^T$ encoding the input (ignoring boundary effects). 
We can create a vector representation for each location using a different weight vector for each output 
channel c to get $z_{ic} = \sum_d x_{i-k:i+k,d}^{\top} w_{d,c}$. This implements a mapping from TD to TC. To reduce 
this to a fixed sized vector, $z \in \mathbb{R}^C$, we can use max-pooling over time to get $z_c = \max_i z_{ic}$. We can 
then pass this into a softmax layer. 

Figure 15.14: Illustration of the TextCNN model for binary sentiment classification. Adapted from Figure 
15.3.5 of [Zha+20]. 

In [Kim14], they applied this model to sequence classification. The idea is to embed each word 
using an embedding layer, and then to compute various features using 1d kernels of different 
widths, to capture patterns of different length scales. We then apply max pooling over time, and 
concatenate the results, and pass to a fully connected layer. See Figure 15.14 for an illustration, and 
cnn1d_sentiment_jax.ipynb for some code. 
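
The convolution followed by max-pooling over time can be sketched in plain Python as follows. This is a simplified illustration rather than the TextCNN model: it uses a one-sided window of width K (positions i to i + K − 1) and only "valid" positions, instead of the symmetric window i − k : i + k used in the text, and the function name and argument layout are assumptions.

```python
def conv1d_maxpool(x, filters):
    """1d convolution over time followed by max-pooling over time.

    x: list of T rows, each with D input features.
    filters[c][d]: list of K kernel taps for output channel c, input channel d.
    Returns a list with one pooled feature per output channel.
    """
    T, D = len(x), len(x[0])
    K = len(filters[0][0])
    pooled = []
    for w_c in filters:
        z = []
        for i in range(T - K + 1):  # valid window positions only
            # sum over input channels d and kernel offsets j
            z_i = sum(x[i + j][d] * w_c[d][j] for d in range(D) for j in range(K))
            z.append(z_i)
        pooled.append(max(z))       # max-pooling over time
    return pooled


x = [[1, 0], [2, 1], [0, 3], [1, 1]]   # T = 4 steps, D = 2 features
filters = [
    [[1, 1], [0, 0]],                  # channel 0: sums feature 0 over two steps
    [[0, 0], [1, -1]],                 # channel 1: differences of feature 1
]
print(conv1d_maxpool(x, filters))
```

Because the pooled vector has one entry per output channel, its size does not depend on T, which is what allows variable-length inputs to be mapped to a fixed-length output.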


15.3.2 Causal 1d CNNs for sequence generation 


To use 1d CNNs in a generative setting, we must convert them to a causal CNN, in which each 
output variable only depends on previously generated variables. (This is also called a convolutional 
Markov model.) In particular, we define the model as follows: 


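$$p(y) = \prod_{t=1}^{T} p(y_t | y_{1:t-1}) = \prod_{t=1}^{T} \mathrm{Cat}\left(y_t \,\Big|\, \mathrm{softmax}\left(\varphi\left(\sum_{\tau=1}^{t-k} w^{\top} y_{\tau:\tau+k}\right)\right)\right) \tag{15.33}$$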


where w is the convolutional filter of size k, and we have assumed a single nonlinearity $\varphi$ and 
categorical output, for notational simplicity. This is like regular 1d convolution except we “mask out” 
future inputs, so that $y_t$ only depends on the past values, rather than past and future values. This is 
called causal convolution. We can of course use deeper models, and we can condition on input 
features x. 

Figure 15.15: Illustration of the wavenet model using dilated (atrous) convolutions, with dilation factors of 1, 
2, 4 and 8. From Figure 3 of [Oor+16]. Used with kind permission of Aaron van den Oord. 

In order to capture long-range dependencies, we can use dilated convolution (Section 14.4.1), as 
illustrated in Figure 15.15. This model has been successfully used to create a state of the art text 
to speech (TTS) synthesis system known as wavenet [Oor+16]. In particular, they stack 10 causal 
1d convolutional layers with dilation rates 1,2,4,...,256,512 to get a convolutional block with an 
effective receptive field of 1024. (They left-padded the input sequences with a number of zeros equal 
to the dilation rate before every layer, so that every layer has the same length.) They then repeat 
this block 3 times to compute deeper features. 
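
The receptive field quoted above can be checked with a short calculation, and a single causal dilated layer can be sketched directly. The kernel size of 2 is an assumption (it is what makes ten layers with dilations 1 to 512 give a receptive field of 1024), and both function names are hypothetical. Note that in this sketch the output at time t includes the input at time t; a generative model would shift the inputs by one step so that the prediction for $y_t$ never sees $y_t$ itself.

```python
def receptive_field(kernel_size, dilations):
    """Receptive field of a stack of dilated 1d convolutions (stride 1)."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)


def causal_dilated_conv(y, w, dilation):
    """out[t] = sum_j w[j] * y[t - j * dilation], with zeros left of the start."""
    out = []
    for t in range(len(y)):
        acc = 0.0
        for j, w_j in enumerate(w):
            s = t - j * dilation
            if s >= 0:  # implicit left zero-padding
                acc += w_j * y[s]
        out.append(acc)
    return out


dilations = [2 ** i for i in range(10)]  # 1, 2, 4, ..., 512
print(receptive_field(2, dilations))
print(causal_dilated_conv([1, 2, 3, 4], [1, 1], dilation=2))
```

Because each output only reads positions at or before t, changing a later input never changes an earlier output, and the implicit zero padding keeps every layer the same length as its input.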

In wavenet, the conditioning information x is a set of linguistic features derived from an input 
sequence of words; the model then generates raw audio using the above model. It is also possible to 
create a fully end-to-end approach, which starts with raw words rather than linguistic features (see 
[Wan+17]). 

Although wavenet produces high quality speech, it is too slow for use in production systems. 
However, it can be “distilled” into a parallel generative model [Oor+18]. We discuss these kinds of 
parallel generative models in the sequel to this book, [Mur23]. 


15.4 Attention 


In all of the neural networks we have considered so far, the hidden activations are a linear combination 
of the input activations, followed by a nonlinearity: $Z = \varphi(XW)$, where $X \in \mathbb{R}^{m \times v}$ are the hidden 
feature vectors, and $W \in \mathbb{R}^{v \times v'}$ are a fixed set of weights that are learned on a training set to 
produce $Z \in \mathbb{R}^{m \times v'}$ outputs. 

However, we can imagine a more flexible model in which the weights depend on the inputs, i.e., 
$Z = \varphi(XW(X))$. This kind of multiplicative interaction is called attention. More generally, 
we can write $Z = \varphi(VW(Q, K))$, where $Q \in \mathbb{R}^{m \times q}$ are a set of queries (derived from X) used to 
describe what each input is “looking for”, $K \in \mathbb{R}^{m \times q}$ are a set of keys (derived from X) used to 
describe what each input vector contains, and $V \in \mathbb{R}^{m \times v}$ are a set of values (derived 
from X) used to describe how each input should be transmitted to the output. (We usually compute 
these quantities using linear projections of the input, $Q = W_q X$, $K = W_k X$, and $V = W_v X$.) 

When using attention to compute output $z_j$, we use its corresponding query $q_j$ and compare it to 
each key $k_i$ to get a similarity score, $0 \le \alpha_{ij} \le 1$, where $\sum_i \alpha_{ij} = 1$; we then set $z_j = \sum_i \alpha_{ij} v_i$. 
(We assume $\varphi(u) = u$ is the identity function.) For example, suppose V = X, and query $q_j$ equally 
matches keys 1 and 2, so $\alpha_{1j} = \alpha_{2j} = 0.5$; then we have $z_j = 0.5 x_1 + 0.5 x_2$. Thus the outputs 
become a dynamic weighted combination of the inputs, rather than a fixed weighted combination. 
And rather than learning the weight matrix, we learn the projection matrices $W_q$, $W_k$, and $W_v$. We 
explain this in more detail below. 

Figure 15.16: Attention computes a weighted average of a set of values, where the weights are derived by 
comparing the query vector to a set of keys. From Figure 10.3.1 of [Zha+20]. Used with kind permission of 
Aston Zhang. 

Note that attention was originally developed for natural language sequence models. However, 
nowadays it is applied to a variety of models, including vision models. Our presentation in the 
following sections is based on [Zha+20, Chap. 10]. 


15.4.1 Attention as soft dictionary lookup 


We will focus on a single output vector, with corresponding query vector q. We can think of attention 
as a dictionary lookup, in which we compare the query q to each key $k_i$, and then retrieve the 
corresponding value $v_i$. To make this lookup operation differentiable, instead of retrieving a single 
value $v_i$, we compute a convex combination of the values, as follows: 

$$\mathrm{Attn}(q, (k_1, v_1), \ldots, (k_m, v_m)) = \mathrm{Attn}(q, (k_{1:m}, v_{1:m})) = \sum_{i=1}^{m} \alpha_i(q, k_{1:m}) v_i \in \mathbb{R}^{v} \tag{15.34}$$

where $\alpha_i(q, k_{1:m})$ is the i’th attention weight; these weights satisfy $0 \le \alpha_i(q, k_{1:m}) \le 1$ for each i 
and $\sum_i \alpha_i(q, k_{1:m}) = 1$. 

The attention weights can be computed from an attention score function $a(q, k_i) \in \mathbb{R}$, that 
computes the similarity of query q to key $k_i$. We will discuss several such score functions below. 
Given the scores, we can compute the attention weights using the softmax function: 

$$\alpha_i(q, k_{1:m}) = \mathrm{softmax}_i([a(q, k_1), \ldots, a(q, k_m)]) = \frac{\exp(a(q, k_i))}{\sum_{j=1}^{m} \exp(a(q, k_j))} \tag{15.35}$$

See Figure 15.16 for an illustration. 

In some cases, we want to restrict attention to a subset of the dictionary, corresponding to valid 
entries. For example, we might want to pad sequences to a fixed length (for efficient minibatching), 
in which case we should “mask out” the padded locations. This is called masked attention. We 
can implement this efficiently by setting the attention score for the masked entries to a large negative 
number, such as $-10^6$, so that the corresponding softmax weights will be 0. (This is analogous to 
causal convolution, discussed in Section 15.3.2.) 
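
The soft lookup and the masking trick can be sketched for a single query as follows. The scores are taken as given (any score function could produce them); the function name, the boolean `mask` convention (True means "keep"), and the max-subtraction for numerical stability are assumptions of this sketch.

```python
import math


def attention(scores, values, mask=None):
    """Softmax over scores, then a weighted average of the value vectors."""
    if mask is not None:
        # masked entries get a large negative score, so their weight is ~0
        scores = [s if keep else -1e6 for s, keep in zip(scores, mask)]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return out, weights


scores = [0.0, 0.0, 5.0]
values = [[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]]
print(attention(scores, values))
print(attention(scores, values, mask=[True, True, False]))  # third entry is padding
```

Without the mask, the third entry dominates because it has the largest score; with the mask, its weight vanishes and the output is the average of the first two values.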


15.4.2 Kernel regression as non-parametric attention 


We can interpret attention in terms of kernel regression, which is a nonparametric model which we 
discuss in Section 16.3.5. In brief, this is a model where the predicted output at query point x is a 
weighted combination of all the target labels $y_i$, where the weights depend on the similarity of query 
point x to each training point $x_i$: 

$$f(x) = \sum_{i=1}^{n} \alpha_i(x, x_{1:n}) y_i \tag{15.36}$$


where $\alpha_i(x, x_{1:n}) \ge 0$ measures the normalized similarity of test input x to training input $x_i$. This 
similarity measure is usually computed by defining the attention score in terms of a density kernel, 
such as the Gaussian: 

$$\mathcal{K}_\sigma(u) = \frac{1}{\sqrt{2\pi\sigma^2}} e^{-\frac{1}{2\sigma^2} u^2} \tag{15.37}$$

where σ is called the bandwidth. We then define $a(x, x_i) = \mathcal{K}_\sigma(x - x_i)$. 

Because the scores are normalized, we can drop the $\frac{1}{\sqrt{2\pi\sigma^2}}$ term. In addition, we rewrite the kernel 
in terms of $\beta^2 = 1/\sigma^2$ to get 

$$\mathcal{K}(u; \beta) = \exp\left(-\frac{\beta^2}{2} u^2\right) \tag{15.38}$$


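Plugging this in to Equation (15.36), we get 

$$f(x) = \sum_{i=1}^{n} \alpha_i(x, x_{1:n}) y_i \tag{15.39}$$
$$= \sum_{i=1}^{n} \frac{\exp[-\frac{1}{2}((x - x_i)\beta)^2]}{\sum_{j=1}^{n} \exp[-\frac{1}{2}((x - x_j)\beta)^2]} y_i \tag{15.40}$$
$$= \sum_{i=1}^{n} \mathrm{softmax}_i\left[-\tfrac{1}{2}((x - x_1)\beta)^2, \ldots, -\tfrac{1}{2}((x - x_n)\beta)^2\right] y_i \tag{15.41}$$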


We can interpret this as a form of nonparametric attention, where the queries are the test points 
x, the keys are the training inputs $x_i$, and the values are the training labels $y_i$. If we set β = 1, 
the resulting attention matrix $A_{ji} = \alpha_i(x_j, x_{1:n})$ for test input j is shown in Figure 15.17a. The 
resulting predicted curve is shown in Figure 15.17b. 

The size of the diagonal band in Figure 15.17a, and hence the sparsity of the attention mechanism, 
depends on the parameter β. If we increase β, corresponding to reducing the kernel bandwidth, the 
band will get narrower, but the model will start to overfit. 

Figure 15.17: Kernel regression in 1d. (a) Kernel weight matrix. (b) Resulting predictions on a dense grid of 
test points. Generated by kernel_regression_attention.ipynb. 
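
Equation (15.41) translates almost line by line into code. The following sketch is independent of the notebook mentioned above; the function name and the toy data are assumptions chosen for illustration.

```python
import math


def kernel_regression(x_query, x_train, y_train, beta=1.0):
    """Predict f(x_query) as a softmax-weighted average of the training labels."""
    # attention scores: negative scaled squared distance from query to each key
    scores = [-0.5 * ((x_query - x_i) * beta) ** 2 for x_i in x_train]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]  # one row of the attention matrix
    prediction = sum(w * y for w, y in zip(weights, y_train))
    return prediction, weights


x_train = [0.0, 1.0, 2.0]
y_train = [0.0, 10.0, 20.0]
print(kernel_regression(1.0, x_train, y_train, beta=1.0))
print(kernel_regression(0.1, x_train, y_train, beta=10.0))
```

With a large β the weights concentrate on the nearest training point, so the prediction approaches that point's label; this is the narrow-band, overfitting regime described above.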


15.4.3 Parametric attention 


In Section 15.4.2, we defined the attention score in terms of the Gaussian kernel, comparing a query 
(test point) to each of the values in the training set. However, non-parametric methods do not scale 
well to large training sets, or high-dimensional inputs. We will therefore turn our attention (no pun 
intended) to parametric models, where we have a fixed set of keys and values, and where we compare 
queries and keys in a learned embedding space. 

There are several ways to do this. In the general case, the query $q \in \mathbb{R}^q$ and the key $k \in \mathbb{R}^k$ may 
have different sizes. To compare them, we can map them to a common embedding space of size h by 
computing $W_q q$ and $W_k k$, where $W_q \in \mathbb{R}^{h \times q}$ and $W_k \in \mathbb{R}^{h \times k}$. We can then pass these into an 
MLP to get the following additive attention scoring function: 

$$a(q, k) = w_v^{\top} \tanh(W_q q + W_k k) \in \mathbb{R} \tag{15.42}$$


A more computationally efficient approach is to assume the queries and keys both have length 
d, so we can compute $q^{\top} k$ directly. If we assume these are independent random variables with 0 
mean and unit variance, the mean of their inner product is 0, and the variance is d. (This follows 
from Equation (2.34) and Equation (2.39).) To ensure the variance of the inner product remains 
1 regardless of the size of the inputs, it is standard to divide by $\sqrt{d}$. This gives rise to the scaled 
dot-product attention: 

$$a(q, k) = q^{\top} k / \sqrt{d} \in \mathbb{R} \tag{15.43}$$


In practice, we usually deal with minibatches of n vectors at a time. Let the corresponding matrices 
of queries, keys and values be denoted by $Q \in \mathbb{R}^{n \times d}$, $K \in \mathbb{R}^{m \times d}$, $V \in \mathbb{R}^{m \times v}$. Then we can compute 
the attention-weighted outputs as follows: 


$$\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{\top}}{\sqrt{d}}\right) V \in \mathbb{R}^{n \times v} \tag{15.44}$$

where the softmax function is applied row-wise. See attention_jax.ipynb for some sample 
code. 

Figure 15.18: Illustration of seq2seq with attention for English to French translation. Used with kind permission 
of Minh-Thang Luong. 

Figure 15.19: Illustration of the attention heatmaps generated while translating two sentences from Spanish to 
English. (a) Input is “hace mucho frio aqui.”, output is “it is very cold here.”. (b) Input is “¿todavia estan en 
casa?”, output is “are you still at home?”. Note that when generating the output token “home”, the model 
should attend to the input token “casa”, but in fact it seems to attend to the input token “?”. Adapted from 
https://www.tensorflow.org/tutorials/text/nmt_with_attention. 
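
The scaled dot-product attention of Equation (15.44) can be sketched without any array library by treating matrices as lists of rows. This is a readability-oriented illustration, not the code in attention_jax.ipynb; the function name is an assumption.

```python
import math


def scaled_dot_product_attention(Q, K, V):
    """Q: n rows of length d; K: m rows of length d; V: m rows of length v.

    Returns n rows of length v: softmax(Q K^T / sqrt(d)) V, softmax per row.
    """
    d = len(Q[0])
    out = []
    for q in Q:
        # scaled dot-product score of this query against every key
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        out.append([sum(w * v[j] for w, v in zip(weights, V))
                    for j in range(len(V[0]))])
    return out


K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(scaled_dot_product_attention([[0.0, 0.0]], K, V))    # uniform weights
print(scaled_dot_product_attention([[100.0, 0.0]], K, V))  # matches the first key
```

A query that is orthogonal to all keys yields uniform weights and hence the mean of the values, whereas a query strongly aligned with one key returns approximately that key's value.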


15.4.4 Seq2Seq with attention 


Recall the seq2seq model from Section 15.2.3. This uses an RNN decoder of the form $h_t^d = 
f_d(h_{t-1}^d, y_{t-1}, c)$, where c is a fixed-length context vector, representing the encoding of the input 
$x_{1:T}$. Usually we set $c = h_T^e$, which is the final state of the encoder RNN (or we use a bidirectional 
RNN with average pooling). However, for tasks such as machine translation, this can result in poor 
performance, since the output does not have access to the input words themselves. We can avoid 
this bottleneck by allowing the output words to directly “look at” the input words. But which inputs 
should it look at? After all, word order is not always preserved across languages (e.g., German often 
puts verbs at the end of a sentence), so we need to infer the alignment between source and target. 

We can solve this problem (in a differentiable way) by using (soft) attention, as first proposed 
in [BCB15; LPM15]. In particular, we can replace the fixed context vector c in the decoder with a 
dynamic context vector $c_t$ computed as follows: 

$$c_t = \sum_{i=1}^{T} \alpha_i(h_{t-1}^d, h_{1:T}^e) h_i^e \tag{15.45}$$

This uses attention where the query is the hidden state of the decoder at the previous step, $h_{t-1}^d$, 
the keys are all the hidden states from the encoder, and the values are also the hidden states from 
the encoder. (When the RNN has multiple hidden layers, we usually take the top layer from the 
encoder, as the keys and values, and the top layer of the decoder as the query.) This context vector 
is concatenated with the input vector of the decoder, $y_{t-1}$, and fed into the decoder, along with the 
previous hidden state $h_{t-1}^d$, to create $h_t^d$. See Figure 15.18 for an illustration of the overall model. 

We can train this model in the usual way on sentence pairs, and then use it to perform machine 
translation. (See nmt_attention_jax.ipynb for some sample code.) We can also visualize the attention 
weights computed at each step of decoding, to get an idea of which parts of the input the model 
thinks are most relevant for generating the corresponding output. Some examples are shown in 
Figure 15.19. 


15.4.5 Seq2vec with attention (text classification) 


We can also use attention with sequence classifiers. For example [Raj+18] apply an RNN classifier 
to the problem of predicting if a patient will die or not. The input is a set of electronic health 
records, which is a time series containing structured data, as well as unstructured text (clinical 
notes). Attention is useful for identifying “relevant” parts of the input, as illustrated in Figure 15.20. 


15.4.6 Seq+Seq2Vec with attention (text pair classification) 


Suppose we see the sentence “A person on a horse jumps over a log” (call this the premise) and 
then we later read “A person is outdoors on a horse” (call this the hypothesis). We may reasonably 
say that the premise entails the hypothesis, meaning that the hypothesis is more likely given the 
premise.³ Now suppose the hypothesis is “A person is at a diner ordering an omelette”. In this case, 
we would say that the premise contradicts the hypothesis, since the hypothesis is less likely given 
the premise. Finally, suppose the hypothesis is “A person is training his horse for a competition”. 
In this case, we see that the relationship between premise and hypothesis is neutral, since the 
hypothesis may or may not follow from the premise. The task of classifying a sentence pair into these 
three categories is known as textual entailment or “natural language inference”. A standard 
benchmark in this area is the Stanford Natural Language Inference or SNLI corpus [Bow+15]. 
This consists of 550,000 labeled sentence pairs. 

An interesting solution to this classification problem was presented in [Par+16a]; at the time, it 
was the state of the art on the SNLI dataset. The overall approach is sketched in Figure 15.21. Let 
$A = (a_1, \ldots, a_m)$ be the premise and $B = (b_1, \ldots, b_n)$ be the hypothesis, where $a_i, b_j \in \mathbb{R}^E$ are
embedding vectors for the words. The model has 3 steps. First, each word in the premise, $a_i$, attends


3. Note that the premise does not logically imply the hypothesis, since the person could be horse-back riding indoors, 
but generally people ride horses outdoors. Also, we are assuming the phrase “a person” refers to the same person in 
the two sentences. 
Figure 15.20: Example of an electronic health record. In this example, 24h after admission to the hospital, the
RNN classifier predicts the risk of death as 19.9%; the patient ultimately died 10 days after admission. The 
“relevant” keywords from the input clinical notes are shown in red, as identified by an attention mechanism. 


From Figure 3 of [Raj+18]. Used with kind permission of Alvin Rajkomar. 


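to each word in the hypothesis, $b_j$, to compute an attention weight

$$e_{ij} = f(a_i)^{\mathsf{T}} f(b_j) \tag{15.46}$$

where $f : \mathbb{R}^E \to \mathbb{R}^D$ is an MLP; we then compute a weighted average of the matching words in the
hypothesis,

$$\beta_i = \sum_{j=1}^{n} \frac{\exp(e_{ij})}{\sum_{k=1}^{n} \exp(e_{ik})} \, b_j \tag{15.47}$$

Next, we compare $a_i$ with $\beta_i$ by mapping their concatenation to a hidden space using an MLP
$g : \mathbb{R}^{2E} \to \mathbb{R}^H$:

$$v_{A,i} = g([a_i, \beta_i]), \quad i = 1, \ldots, m \tag{15.48}$$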
Figure 15.21: Illustration of sentence pair entailment classification using an MLP with attention to align
the premise (“I do need sleep”) with the hypothesis (“I am tired”). White squares denote active attention 
weights, blue squares are inactive. (We are assuming hard 0/1 attention for simplicity.) From Figure 15.5.2 
of [Zha+20]. Used with kind permission of Aston Zhang. 


Finally, we aggregate over the comparisons to get an overall similarity of premise to hypothesis:

$$v_A = \sum_{i=1}^{m} v_{A,i} \tag{15.49}$$

We can similarly compare the hypothesis to the premise using

$$\alpha_j = \sum_{i=1}^{m} \frac{\exp(e_{ij})}{\sum_{k=1}^{m} \exp(e_{kj})} \, a_i \tag{15.50}$$

$$v_{B,j} = g([b_j, \alpha_j]), \quad j = 1, \ldots, n \tag{15.51}$$

$$v_B = \sum_{j=1}^{n} v_{B,j} \tag{15.52}$$

At the end, we classify the output using another MLP $h : \mathbb{R}^{2H} \to \mathbb{R}^3$:

$$\hat{y} = h([v_A, v_B]) \tag{15.53}$$


See entailment_attention_mlp_jax.ipynb for some sample code.
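The attend, compare, and aggregate steps above can be sketched in a few lines of plain Python. This is only an illustrative sketch, not the notebook's implementation: the function names (`attend_compare_aggregate`, `weighted_sum`) are hypothetical, and identity maps stand in for the MLPs $f$ and $g$, which in a real model would be learned. The final classifier $h$ is omitted.

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]

def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))

def weighted_sum(weights, vectors):
    dim = len(vectors[0])
    return [sum(w * vec[c] for w, vec in zip(weights, vectors)) for c in range(dim)]

def attend_compare_aggregate(A, B, f, g):
    """Return (v_A, v_B) for premise A (m vectors) and hypothesis B (n vectors)."""
    m, n = len(A), len(B)
    fA, fB = [f(a) for a in A], [f(b) for b in B]
    # Attention scores e[i][j] = f(a_i)^T f(b_j).
    e = [[dot(fa, fb) for fb in fB] for fa in fA]
    # Soft alignment of each premise word with the hypothesis (softmax over j).
    beta = [weighted_sum(softmax(e[i]), B) for i in range(m)]
    # Soft alignment of each hypothesis word with the premise (softmax over i).
    alpha = [weighted_sum(softmax([e[i][j] for i in range(m)]), A) for j in range(n)]
    # Compare: g is applied to the concatenation (list + list concatenates).
    v_A_terms = [g(a + b) for a, b in zip(A, beta)]
    v_B_terms = [g(b + a) for b, a in zip(B, alpha)]
    # Aggregate: sum the comparison vectors over words.
    v_A = [sum(col) for col in zip(*v_A_terms)]
    v_B = [sum(col) for col in zip(*v_B_terms)]
    return v_A, v_B

# Toy example; identity maps are placeholders for the learned MLPs f and g.
identity = lambda x: list(x)
A = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # m = 3 premise embeddings
B = [[1.0, 0.0], [0.0, 2.0]]               # n = 2 hypothesis embeddings
v_A, v_B = attend_compare_aggregate(A, B, identity, identity)
print(v_A, v_B)
```

The concatenation $[v_A, v_B]$ would then be passed to the classifier $h$ to produce scores for the three classes.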

We can modify this model to learn other kinds of mappings from sentence pairs to output labels. 
For example, in the semantic textual similarity task, the goal is to predict how semantically 
related two input sentences are. A standard dataset for this is the STS Benchmark [Cer+17],
where relatedness ranges from 0 (meaning unrelated) to 5 (meaning maximally related). 


15.4.7 Soft vs hard attention 


If we force the attention heatmap to be sparse, so that each output can only attend to one input 
location instead of a weighted combination of all of them, the method is called hard attention. We 
Figure 15.22: Image captioning using attention. (a) Soft attention. Generates “a woman is throwing a frisbee
in a park”. (b) Hard attention. Generates “a man and a woman playing frisbee in a field”. From Figure 6 of 
[Xu+15]. Used with kind permission of Kelvin Xu. 


compare these two approaches for an image captioning problem in Figure 15.22. Unfortunately, hard 
attention results in a nondifferentiable training objective, and requires methods such as reinforcement 
learning to fit the model. See [Xu+15] for the details. 

It seems from the above examples that these attention heatmaps can “explain” why the model 
generates a given output. However, the interpretability of attention is controversial (see e.g., [JW19; 
WP19; SS19; Bru+19] for discussion). 


15.5 Transformers 


The transformer model [Vas+17] is a seq2seq model which uses attention in the encoder as well as 
the decoder, thus eliminating the need for RNNs, as we explain below. Transformers have been used 
for many (conditional) sequence generation tasks, such as machine translation [Vas+17], constituency 
parsing [Vas+17], music generation [Hua+18], protein sequence generation [Mad+20; Cho+20b], 
abstractive text summarization [Zha+19a], image generation [Par+18] (treating the image as a 
rasterized 1d sequence), etc. 

The transformer is a rather complex model that uses several new kinds of building blocks or layers. 
We introduce these new blocks below, and then discuss how to put them all together.⁴


15.5.1 Self-attention 


In Section 15.4.4 we showed how the decoder of an RNN could use attention to the input sequence in 
order to capture contextual embeddings of each input. However, rather than the decoder attending to
the encoder, we can modify the model so the encoder attends to itself. This is called self attention 
[CDL16; Par+16b]. 

In more detail, given a sequence of input tokens $x_1, \ldots, x_n$, where $x_i \in \mathbb{R}^d$, self-attention can
generate a sequence of outputs of the same size using


$$y_i = \mathrm{Attn}(x_i, (x_1, x_1), \ldots, (x_n, x_n)) \tag{15.54}$$


4. For a more detailed introduction, see https://huggingface.co/course/chapter1.
Figure 15.23: Illustration of how encoder self-attention for the word “it” differs depending on the input
context. From https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html. Used
with kind permission of Jakob Uszkoreit.


where the query is $x_i$, and the keys and values are all the (valid) inputs $x_1, \ldots, x_n$.

To use this in a decoder, we can set $x_i = y_{i-1}$, and $n = i - 1$, so all the previously generated
outputs are available. At training time, all the outputs are already known, so we can evaluate the 
above function in parallel, overcoming the sequential bottleneck of using RNNs. 
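As a rough illustration of self-attention, here is a minimal sketch in plain Python. It assumes scaled dot-product scores for Attn and uses the raw inputs as queries, keys, and values, with no learned projections; the `causal` flag is a hypothetical parameter that restricts position $i$ to positions up to $i$ of the (already shifted) decoder input, which corresponds to attending only to previously generated outputs.

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]

def self_attention(X, causal=False):
    """y_i = Attn(x_i, (x_1, x_1), ..., (x_n, x_n)) with scaled dot-product scores."""
    d = len(X[0])
    Y = []
    for i, q in enumerate(X):
        # In causal mode, position i may only look at positions 0..i.
        visible = X[: i + 1] if causal else X
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in visible]
        w = softmax(scores)
        Y.append([sum(wj * v[c] for wj, v in zip(w, visible)) for c in range(d)])
    return Y

X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
Y_full = self_attention(X)                  # encoder-style: every position sees all inputs
Y_causal = self_attention(X, causal=True)   # decoder-style: no access to later positions
print(Y_full)
print(Y_causal)
```

Each output row depends only on the inputs, not on other output rows, which is why all rows can be computed in parallel at training time.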

In addition to improved speed, self-attention can give improved representations of context. As 
an example, consider translating the English sentences “The animal didn’t cross the street because
it was too tired” and “The animal didn’t cross the street because it was too wide” into French. To
generate a pronoun of the correct gender in French, we need to know what “it” refers to (this is called 
coreference resolution). In the first case, the word “it” refers to the animal. In the second case, 
the word “it” now refers to the street. 

Figure 15.23 illustrates how self attention applied to the English sentence is able to resolve this 
ambiguity. In the first sentence, the representation for “it” depends on the earlier representations of 
“animal”, whereas in the latter, it depends on the earlier representations of “street”. 


15.5.2 Multi-headed attention 


If we think of an attention matrix as like a kernel matrix (as discussed in Section 15.4.2), it is natural 
to want to use multiple attention matrices, to capture different notions of similarity. This is the 
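basic idea behind multi-headed attention (MHA). In more detail, given a query $q \in \mathbb{R}^{d_q}$, keys
$k_j \in \mathbb{R}^{d_k}$, and values $v_j \in \mathbb{R}^{d_v}$, we define the $i$'th attention head to be


$$h_i = \mathrm{Attn}(W_i^{(q)} q, \{W_i^{(k)} k_j, W_i^{(v)} v_j\}) \in \mathbb{R}^{p_v} \tag{15.55}$$


where $W_i^{(q)} \in \mathbb{R}^{p_q \times d_q}$, $W_i^{(k)} \in \mathbb{R}^{p_k \times d_k}$, and $W_i^{(v)} \in \mathbb{R}^{p_v \times d_v}$ are projection matrices. We then stack
the $h$ heads together, and project to $\mathbb{R}^{p_o}$ using

$$h = \mathrm{MHA}(q, \{k_j, v_j\}) = W_o \begin{pmatrix} h_1 \\ \vdots \\ h_h \end{pmatrix} \in \mathbb{R}^{p_o} \tag{15.56}$$


where $h_i$ is defined in Equation (15.55), and $W_o \in \mathbb{R}^{p_o \times h p_v}$. If we set $p_q h = p_k h = p_v h = p_o$, we
can compute all the output heads in parallel. See multi_head_attention_jax.ipynb for some sample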


Figure 15.24: Multi-head attention. Adapted from Figure 9.3.8 of [Zha+20]. 


code. 
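The following is a small, unoptimized sketch of multi-headed attention for a single query, written in plain Python rather than JAX. It assumes scaled dot-product attention inside each head and random (untrained) projection matrices; the helper names `matvec`, `attn`, and `random_matrix` are illustrative and are not taken from the notebook. The heads are computed in a loop here for readability, whereas a practical implementation would batch them.

```python
import math
import random

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attn(q, keys, values):
    """Scaled dot-product attention for one query (an assumed choice of Attn)."""
    scale = math.sqrt(len(q))
    weights = softmax([sum(a * b for a, b in zip(q, k)) / scale for k in keys])
    return [sum(w * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))]

def multi_head_attention(q, keys, values, heads, W_o):
    """heads is a list of (W_q, W_k, W_v) projection triples, one per head."""
    stacked = []
    for W_q, W_k, W_v in heads:
        stacked += attn(matvec(W_q, q),
                        [matvec(W_k, k) for k in keys],
                        [matvec(W_v, v) for v in values])
    return matvec(W_o, stacked)   # project the stacked heads to the output space

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
d, p, num_heads, p_o = 4, 2, 2, 4          # chosen so that p * num_heads == p_o
heads = [tuple(random_matrix(p, d, rng) for _ in range(3)) for _ in range(num_heads)]
W_o = random_matrix(p_o, num_heads * p, rng)
keys = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(5)]
values = keys                               # self-attention: keys and values coincide
out = multi_head_attention(keys[0], keys, values, heads, W_o)
print(out)
```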


15.5.3 Positional encoding 


The performance of “vanilla” self-attention can be low, since attention is permutation invariant, and 
hence ignores the input word ordering. To overcome this, we can concatenate the word embeddings 
with a positional embedding, so that the model knows what order the words occur in. 

One way to do this is to represent each position by an integer. However, neural networks cannot
natively handle integers. To overcome this, we can encode the integer in binary form. For example,
if we assume the sequence length is $n = 8$, we get the following sequence of $d = 3$-dimensional bit
vectors for each location: 000, 001, 010, 011, 100, 101, 110, 111. We see that the right most index
toggles the fastest (has highest frequency), whereas the left most index (most significant bit) toggles
the slowest. (We could of course change this, so that the left most bit toggles fastest.) We can
represent this as a position matrix $P \in \mathbb{R}^{n \times d}$.
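As a quick check of the binary scheme, the short sketch below builds the position matrix for eight positions and three bits, with the most significant bit on the left. The function name `binary_positions` is just an illustrative choice.

```python
def binary_positions(n, d):
    """Return an n x d list of bit vectors, most significant bit first."""
    return [[(i >> (d - 1 - b)) & 1 for b in range(d)] for i in range(n)]

P = binary_positions(n=8, d=3)
for row in P:
    print("".join(str(bit) for bit in row))
```

Reading down the printed rows, the last column alternates at every step while the first column changes only once, matching the description above.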

We can think of the above representation as using a set of basis functions (corresponding to powers 
of 2), where the coefficients are 0 or 1. We can obtain a more compact code by using a different set 
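of basis functions, and real-valued weights. [Vas+17] propose to use a sinusoidal basis, as follows:


$$p_{i,2j} = \sin\left(\frac{i}{C^{2j/d}}\right), \quad p_{i,2j+1} = \cos\left(\frac{i}{C^{2j/d}}\right) \tag{15.57}$$


where $C = 10{,}000$ corresponds to some maximum sequence length. For example, if $d = 4$, the $i$th
row is


$$p_i = \left[\sin\left(\frac{i}{C^{0/4}}\right), \cos\left(\frac{i}{C^{0/4}}\right), \sin\left(\frac{i}{C^{2/4}}\right), \cos\left(\frac{i}{C^{2/4}}\right)\right] \tag{15.58}$$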


Figure 15.25a shows the corresponding position matrix for n = 60 and d = 32. In this case, the 
left-most columns toggle fastest. We see that each row has a real-valued “fingerprint” representing 
its location in the sequence. Figure 15.25b shows some of the basis functions (column vectors) for 
dimensions 6 to 9. 
Figure 15.25: (a) Positional encoding matrix for a sequence of length n = 60 and an embedding dimension of
size d = 32. (b) Basis functions for columns 6 to 9. Generated by positional_encoding_jax.ipynb.


The advantage of this representation is two-fold. First, it can be computed for arbitrary length 
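inputs (up to $T < C$), unlike a learned mapping from integers to vectors. Second, the representation
of one location is linearly predictable from any other, given knowledge of their relative distance. In
particular, we have $p_{t+\phi} = f(p_t)$, where $f$ is a linear transformation. To see this, write
$\omega_k = 1/C^{2k/d}$ for the frequency of the $k$th pair of columns, and note that


$$\begin{pmatrix} \sin(\omega_k (t+\phi)) \\ \cos(\omega_k (t+\phi)) \end{pmatrix} = \begin{pmatrix} \sin(\omega_k t)\cos(\omega_k \phi) + \cos(\omega_k t)\sin(\omega_k \phi) \\ \cos(\omega_k t)\cos(\omega_k \phi) - \sin(\omega_k t)\sin(\omega_k \phi) \end{pmatrix} \tag{15.59}$$


$$= \begin{pmatrix} \cos(\omega_k \phi) & \sin(\omega_k \phi) \\ -\sin(\omega_k \phi) & \cos(\omega_k \phi) \end{pmatrix} \begin{pmatrix} \sin(\omega_k t) \\ \cos(\omega_k t) \end{pmatrix} \tag{15.60}$$


So if $\phi$ is small, then $p_{t+\phi} \approx p_t$. This provides a useful form of inductive bias.

Once we have computed the positional embeddings $P$, we need to combine them with the original
word embeddings $X$ using the following:⁵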


POS(Embed(X)) = X + P. (15.61) 
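To make the sinusoidal encoding and its linear-shift property concrete, here is a minimal plain-Python sketch. The function names `positional_encoding` and `shift` are illustrative, the code assumes an even embedding dimension, and the particular values of the column pair, position, and offset are arbitrary choices for the check.

```python
import math

def positional_encoding(n, d, C=10000.0):
    """Return the n x d sinusoidal position matrix; d is assumed to be even."""
    P = [[0.0] * d for _ in range(n)]
    for i in range(n):
        for j in range(d // 2):
            angle = i / C ** (2 * j / d)
            P[i][2 * j] = math.sin(angle)
            P[i][2 * j + 1] = math.cos(angle)
    return P

def shift(pair, omega, phi):
    """Apply the 2x2 rotation that moves a (sin, cos) pair forward by phi positions."""
    s, c = pair
    return (math.cos(omega * phi) * s + math.sin(omega * phi) * c,
            -math.sin(omega * phi) * s + math.cos(omega * phi) * c)

n, d = 60, 32
P = positional_encoding(n, d)
j, t, phi = 3, 10, 7                       # column pair, position, and offset (arbitrary)
omega = 1.0 / 10000.0 ** (2 * j / d)
predicted = shift((P[t][2 * j], P[t][2 * j + 1]), omega, phi)
actual = (P[t + phi][2 * j], P[t + phi][2 * j + 1])
print(predicted, actual)                   # the two pairs agree up to rounding error
```

The rotation matrix depends only on the offset and the frequency, not on the absolute position, which is what makes relative positions easy for the model to exploit.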


15.5.4 Putting it all together 


A transformer is a seq2seq model that uses self-attention for the encoder and decoder rather than 
an RNN. The encoder uses a series of encoder blocks, each of which uses multi-headed attention 
(Section 15.5.2), residual connections (Section 13.4.4), feedforward layers (Section 13.2), and layer 
normalization (Section 14.2.4.2). More precisely, the encoder block can be defined as follows: 


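def EncoderBlock(X):
    # Self-attention sublayer with a residual connection.
    Z = LayerNorm(MultiHeadAttn(Q=X, K=X, V=X) + X)
    # Position-wise feedforward sublayer with a residual connection.
    E = LayerNorm(FeedForward(Z) + Z)
    return E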


5. A more obvious combination scheme would be to concatenate X and P, but adding takes less space. Furthermore,
since the X embeddings are learned, the model could emulate concatenation by setting the first K dimensions of X,
and the last D − K dimensions of P, to 0, where K is defined implicitly by the sparsity pattern. For more discussion,
see https://bit.ly/3rMGlat.
Figure 15.26: The transformer. From [Wen18]. Used with kind permission of Lilian Weng. Adapted from
Figures 1-2 of [Vas+17]. 


Note that the MHA layer combines information across the sequence, and the feedforward layer 
combines information across the dimensions at each location in parallel. (Most of the parameters 
of large transformer models are stored inside these MLPs, and it has been conjectured that this is 
where most of the “world knowledge” lives [Men+22].) The layer norm can either be applied after 
the module (i.e., z = LN(module(x) + x)) or before (i.e., z = module(LN(x)) + x); these are known
as post-norm and pre-norm, respectively.

The overall encoder is defined by applying positional encoding to the embedding of the input 
sequence, followed by N copies of the encoder block, where N controls the depth of the encoder:


def Encoder(X, N):
    E = POS(Embed(X))
    for n in range(N):
        E = EncoderBlock(E)
    return E


See the LHS of Figure 15.26 for an illustration. 
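The pseudocode for the encoder can be turned into a small runnable sketch. The version below makes several simplifying assumptions that are not part of the model definition above: it uses a single attention head with no learned projections, a layer norm without learned gain or bias, a two-layer ReLU feedforward network with random weights, post-norm ordering, and an input that is assumed to be already embedded and position-encoded.

```python
import math
import random

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def add(u, v):
    return [a + b for a, b in zip(u, v)]

def layer_norm(x, eps=1e-5):
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]

def self_attention(X):
    d = len(X[0])
    out = []
    for q in X:
        w = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in X])
        out.append([sum(wj * v[c] for wj, v in zip(w, X)) for c in range(d)])
    return out

def feed_forward(x, W1, W2):
    hidden = [max(0.0, h) for h in matvec(W1, x)]   # ReLU nonlinearity
    return matvec(W2, hidden)

def encoder_block(X, W1, W2):
    attended = self_attention(X)                     # mixes information across positions
    Z = [layer_norm(add(a, x)) for a, x in zip(attended, X)]
    # The feedforward layer is applied to each position independently.
    return [layer_norm(add(feed_forward(z, W1, W2), z)) for z in Z]

def encoder(E, params):
    for W1, W2 in params:                            # one (W1, W2) pair per block
        E = encoder_block(E, W1, W2)
    return E

rng = random.Random(0)
n, d, hidden, N = 5, 4, 8, 2
params = [([[rng.gauss(0, 0.5) for _ in range(d)] for _ in range(hidden)],
           [[rng.gauss(0, 0.5) for _ in range(hidden)] for _ in range(d)])
          for _ in range(N)]
X = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
E = encoder(X, params)
print(E)
```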

The decoder has a somewhat more complex structure. It is given access to the encoder via 
another multi-head attention block. But it is also given access to previously generated outputs: these 
are shifted, and then combined with a positional embedding, and then fed into a masked (causal) 
multi-head attention model. Finally the output distribution over tokens at each location is computed 
Layer type       Complexity      Sequential ops.   Max. path length
Self-attention   $O(n^2 d)$      $O(1)$            $O(1)$
Recurrent        $O(n d^2)$      $O(n)$            $O(n)$
Convolutional    $O(k n d^2)$    $O(1)$            $O(\log_k n)$


Table 15.1: Comparison of the transformer with other neural sequential generative models. n is the sequence
length, d is the dimensionality of the input features, and k is the kernel size for convolution. Based on Table
1 of [Vas+17].


in parallel. 
In more detail, the decoder block is defined as follows: 


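def DecoderBlock(Y, E):
    # Masked (causal) self-attention over the shifted outputs.
    Z = LayerNorm(MultiHeadAttn(Q=Y, K=Y, V=Y) + Y)
    # Cross-attention: queries from the decoder, keys and values from the encoder.
    Z2 = LayerNorm(MultiHeadAttn(Q=Z, K=E, V=E) + Z)
    D = LayerNorm(FeedForward(Z2) + Z2)
    return D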


The overall decoder is defined by N copies of the decoder block: 


def Decoder(Y, E, N):
    D = POS(Embed(Y))
    for n in range(N):
        D = DecoderBlock(D, E)
    return D


See the RHS of Figure 15.26 for an illustration. 

During training time, all the inputs Y to the decoder are known in advance, since they are derived 
from embedding the lagged target output sequence. During inference (test) time, we need to decode 
sequentially, and use masked attention, where we feed the generated output into the embedding 
layer, and add it to the set of keys/values that can be attended to. (We initialize by feeding in the 
start-of-sequence token.) See transformers_jax.ipynb for some sample code, and [Rus18; Ala18] for a
detailed tutorial on this model. 
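The sequential decoding loop used at test time can be sketched as follows. This is a schematic example only: it uses greedy selection of the next token, which is one possible decoding rule and is our assumption here, and `toy_step` is a hypothetical stand-in for a trained decoder that returns scores over the vocabulary given the tokens generated so far.

```python
def greedy_decode(step_fn, bos, eos, max_len):
    """step_fn(prefix) returns a list of scores over the vocabulary for the next token."""
    tokens = [bos]                      # initialize with the start-of-sequence token
    for _ in range(max_len):
        scores = step_fn(tokens)        # the model sees only the tokens generated so far
        nxt = max(range(len(scores)), key=scores.__getitem__)
        tokens.append(nxt)              # the new token becomes available as a key/value
        if nxt == eos:
            break
    return tokens

# Toy stand-in for a trained decoder: it always prefers (last token + 1) mod vocab_size.
def toy_step(prefix, vocab_size=5):
    target = (prefix[-1] + 1) % vocab_size
    return [1.0 if v == target else 0.0 for v in range(vocab_size)]

print(greedy_decode(toy_step, bos=0, eos=4, max_len=10))   # [0, 1, 2, 3, 4]
```

Unlike training, each step here depends on the previous step's output, so the loop cannot be parallelized across positions.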


15.5.5 Comparing transformers, CNNs and RNNs 


In Figure 15.27, we visually compare three different architectures for mapping a sequence $x_{1:n}$ to
another sequence $y_{1:n}$: a 1d CNN, an RNN, and an attention-based model. Each model makes
different tradeoffs in terms of speed and expressive power, where the latter can be quantified in terms 
of the maximum path length between any two inputs. See Table 15.1 for a summary. 

For a 1d CNN with kernel size $k$ and $d$ feature channels, the time to compute the output is $O(knd^2)$,
which can be done in parallel. We need a stack of $n/k$ layers, or $\log_k(n)$ if we use dilated convolution,
to ensure all pairs can communicate. For example, in Figure 15.27, we see that $x_1$ and $x_5$ are initially
5 apart, and then 3 apart in layer 1, and then connected in layer 2.

For an RNN, the computational complexity is $O(nd^2)$, for a hidden state of size $d$, since we have
to perform matrix-vector multiplication at each step. This is an inherently sequential operation. The 
maximum path length is O(n). 
Figure 15.27: Comparison of (1d) CNNs, RNNs and self-attention models. From Figure 10.6.1 of [Zha+20].
Used with kind permission of Aston Zhang. 


Finally, for self-attention models, every output is directly connected to every input, so the maximum 
path length is $O(1)$. However, the computational cost is $O(n^2 d)$. For short sequences, we typically
have $n \ll d$, so this is fine. For longer sequences, we discuss various fast versions of attention in
Section 15.6. 
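To get a feel for these asymptotic costs, the sketch below plugs example sizes into the leading terms from Table 15.1. The constants are ignored, so the numbers are only indicative, and the sizes used (a feature dimension of 512, a kernel size of 3, and sequence lengths of 100 and 4096) are arbitrary choices for illustration.

```python
def layer_costs(n, d, k):
    """Leading-order operation counts per layer (constants ignored)."""
    return {
        "self_attention": n * n * d,
        "recurrent": n * d * d,
        "convolutional": k * n * d * d,
    }

short = layer_costs(n=100, d=512, k=3)     # n < d: self-attention is the cheapest
long_ = layer_costs(n=4096, d=512, k=3)    # n > d: the quadratic term dominates
for name in short:
    print(f"{name:15s} short={short[name]:>14,d} long={long_[name]:>16,d}")
```

Self-attention and recurrent layers cost the same at leading order when $n = d$; beyond that point the quadratic dependence on sequence length makes self-attention the more expensive of the two.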


15.5.6 Transformers for images * 


CNNs (Chapter 14) are the most common model type for processing image data, since they have 
useful built-in inductive bias, such as locality (due to small kernels), equivariance (due to weight 
tying), and invariance (due to pooling). Surprisingly, it has been found that transformers can also do
well at image classification [Rag+21], at least if trained on enough data. (They need a lot of data to 
overcome their lack of relevant inductive bias.) 

The first model of this kind, known as ViT (vision transformer) [Dos+21], chops the input up into 
16x16 patches, projects each patch into an embedding space, and then passes this set of embeddings
$x_{1:T}$ to a transformer, analogous to the way word embeddings are passed to a transformer. The
input is also prepended with a special [CLASS] embedding, $x_0$. The output of the transformer is a
set of encodings $e_{0:T}$; the model maps $e_0$ to the target class label $y$, and is trained in a supervised
way. See Figure 15.28 for an illustration. 
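The patch extraction step can be illustrated with a short plain-Python sketch. It represents an image as nested lists of shape height × width × channels, assumes the image size is a multiple of the patch size, and uses a made-up 32×32 RGB image; the learned linear projection of each patch and the prepended [CLASS] embedding are omitted.

```python
def image_to_patches(image, patch):
    """Split an H x W x C nested-list image into flattened patches, in row-major order."""
    H, W = len(image), len(image[0])
    assert H % patch == 0 and W % patch == 0, "image size must be a multiple of patch size"
    patches = []
    for r in range(0, H, patch):
        for c in range(0, W, patch):
            vec = [ch
                   for i in range(r, r + patch)
                   for j in range(c, c + patch)
                   for ch in image[i][j]]
            patches.append(vec)
    return patches

# A made-up 32 x 32 RGB image, purely for illustration.
image = [[[float(i + j), 0.0, 1.0] for j in range(32)] for i in range(32)]
patches = image_to_patches(image, patch=16)
print(len(patches), len(patches[0]))   # 4 patches, each of length 16 * 16 * 3 = 768
```

Each flattened patch plays the role that a token plays in a text model, so the resulting sequence length is the number of patches.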

After supervised pretraining, the model is fine-tuned on various downstream classification tasks, 
an approach known as transfer learning (see Section 19.2 for more details). When trained on “small” 
datasets such as ImageNet (which has 1k classes and 1.3M images), they find that they cannot 
outperform a pretrained CNN ResNet model (Section 14.3.4) known as BiT (big transfer) [Kol+20]. 
Figure 15.28: The Vision Transformer (ViT) model. This treats an image as a set of input patches. The
input is prepended with the special CLASS embedding vector (denoted by *) in location 0. The class label
for the image is derived by applying softmax to the final output encoding at location 0. From Figure 1 of
[Dos+21]. Used with kind permission of Alexey Dosovitskiy.


However, when trained on larger datasets, such as ImageNet-21k (with 21k classes and 14M images), 
or the Google-internal JFT dataset (with 18k classes and 303M images), they find that ViT does 
better than BiT at transfer learning.⁶ ViT is also cheaper to train than ResNet at this scale. (However,
training is still expensive: the large ViT model on ImageNet-21k takes 30 days on a Google Cloud 
TPUv3 with 8 cores!) 


15.5.7 Other transformer variants * 


Many extensions of transformers have been published in the last few years. For example, the GShard
paper [Lep+21] shows how to scale up transformers to even more parameters by replacing some of 
the feed forward dense layers with a mixture of experts (Section 13.6.2) regression module. This 
allows for sparse conditional computation, in which only a subset of the model capacity (chosen by 
the gating network) is used for any given input. 

As another example, the conformer paper [Gul+20] showed how to add convolutional layers 
inside the transformer architecture, which was shown to be helpful for various speech recognition 
tasks. 


15.6 Efficient transformers * 


This section is written by Krzysztof Choromanski. 


Regular transformers have $O(N^2)$ time and space complexity, for a sequence of length $N$, which
makes them impractical to apply to long sequences. In the past few years, researchers have proposed 
several more efficient variants of transformers to bypass this difficulty. In this section, we give a 


6. More recent work, specifically the ConvNeXt model of [Liu+22], has shown that CNNs can be made to outperform
ViT.
Figure 15.29: Venn diagram presenting the taxonomy of different efficient transformer architectures. From
[Tay+20b]. Used with kind permission of Yi Tay. 


brief survey of some of these methods (see Figure 15.29 for a summary). For more details, see e.g., 
[Tay+20b; Tay+20a; Lin+21]. 


15.6.1 Fixed non-learnable localized attention patterns 


The simplest modification of the attention mechanism is to constrain it to a fixed non-learnable
localized window, in other words restrict each token to attend only to a pre-selected set of other
tokens. If, for instance, each sequence is chunked into $K$ blocks, each of length $N/K$, and attention
is conducted only within a block, then space/time complexity is reduced from $O(N^2)$ to $O(N^2/K)$. For
$K \gg 1$ this constitutes a substantial overall computational improvement. Such an approach is applied
in particular in [Qiu+19b; Par+18]. The attention patterns do not need to be in the form of blocks.
Other approaches involve strided / dilated windows, or hybrid patterns, where several fixed attention
patterns are combined together [Chi+19b; BPC20].
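The saving from block-local attention is easy to verify by counting the query-key pairs that are actually scored. The sketch below builds the block-diagonal mask explicitly for a tiny sequence and compares the count against the closed-form expression; the function names are illustrative and the sequence length is assumed to be divisible by the number of blocks.

```python
def block_mask(N, K):
    """mask[i][j] is True when tokens i and j fall in the same one of K equal blocks."""
    assert N % K == 0, "N is assumed to be divisible by K"
    block = N // K
    return [[i // block == j // block for j in range(N)] for i in range(N)]

def block_attention_pairs(N, K):
    """Number of query-key pairs scored: K blocks of (N/K)^2 pairs, i.e. N^2 / K."""
    block = N // K
    return K * block * block

mask = block_mask(N=8, K=2)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))

print(block_attention_pairs(8, 2), "of", 8 * 8)
print(block_attention_pairs(1024, 16), "of", 1024 * 1024)
```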


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 




15.6.2 Learnable sparse attention patterns 


A natural extension of the above approach is to allow the above compact patterns to be learned. The 
attention is still restricted to pairs of tokens within a single partition of some partitioning of the set 
of all the tokens, but now those partitionings are trained. In this class of methods we can distinguish 
two main approaches: based on hashing and clustering. In the hashing scenario all tokens are hashed 
and thus different partitions correspond to different hashing-buckets. This is the case for instance for 
the Reformer architecture [KKL20], where locality sensitive hashing (LSH) is applied. That leads 
to time complexity $O(N M^2 \log(M))$ of the attention module, where $M$ stands for the dimensionality
of tokens’ embeddings.

Hashing approaches require the set of queries to be identical to the set of keys. Furthermore, 
the number of hashes needed for precise partitioning (which in the above expression is treated as a 
constant) can be a large constant. In the clustering approach, tokens are clustered using standard 
clustering algorithms such as K-means (Section 21.3); this is known as the “clustering transformer” 
[Roy+20]. As in the block case, if $K$ equal-size clusters are used then space complexity of the
attention module is reduced to $O(N^2/K)$. In practice $K$ is often taken to be of order $K = O(\sqrt{N})$, yet
imposing that the clusters be similar in size is in practice difficult.


15.6.3 Memory and recurrence methods 


In some approaches, a side memory module can access several tokens simultaneously. This method is 
often instantiated in the form of a global memory algorithm as used in [Lee+19; Zah+20]. 

Another approach is to connect different local blocks via recurrence. A flagship example of this 
approach is the class of Transformer-XL methods [Dai+19].


15.6.4 Low-rank and kernel methods 


In this section, we discuss methods that approximate attention using low rank matrices. In [She+18;
Kat+20] they approximate the attention matrix $A$ directly by a low rank matrix, so that


$$A_{ij} = \phi(q_i)^{\mathsf{T}} \phi(k_j) \tag{15.62}$$


where $\phi(x) \in \mathbb{R}^M$ is some finite-dimensional vector with $M < D$. One can leverage this structure to
compute $AV$ in $O(N)$ time. Unfortunately, for softmax attention, the matrix $A$ is not low rank.

In Linformer [Wan+20a], they instead transform the keys and values via random Gaussian projections.
They then apply the theory of the Johnson-Lindenstrauss Transform [AL13] to approximate
softmax attention in this lower dimensional space. 

In Performer [Cho+20a; Cho+20b], they show that the attention matrix can be computed using 
a (positive definite) kernel function. We define kernel functions in Section 17.1, but the basic idea 
is that $\mathcal{K}(q, k) \ge 0$ is some measure of similarity between $q \in \mathbb{R}^D$ and $k \in \mathbb{R}^D$. For example, the
Gaussian kernel, also called the radial basis function kernel, has the form


$$\mathcal{K}_{\text{gauss}}(q, k) = \exp\left(-\frac{1}{2\sigma^2} \|q - k\|_2^2\right) \tag{15.63}$$

To see how this can be used to compute an attention matrix, note that [Cho+20a] show the following:
Figure 15.30: Attention matrix $A$ rewritten as a product of two lower rank matrices $Q'$ and $(K')^{\mathsf{T}}$ with random
feature maps $\phi(q_i) \in \mathbb{R}^M$ and $\phi(k_j) \in \mathbb{R}^M$ for the corresponding queries/keys stored in the rows/columns.
Used with kind permission of Krzysztof Choromanski.


$$A_{i,j} = \exp\left(\frac{q_i^{\mathsf{T}} k_j}{\sqrt{D}}\right) = \exp\left(-\frac{\|q_i - k_j\|_2^2}{2\sqrt{D}}\right) \times \exp\left(\frac{\|q_i\|_2^2}{2\sqrt{D}}\right) \times \exp\left(\frac{\|k_j\|_2^2}{2\sqrt{D}}\right) \tag{15.64}$$


The first term in the above expression is equal to $\mathcal{K}_{\text{gauss}}(q_i D^{-1/4}, k_j D^{-1/4})$ with $\sigma = 1$, and the
other two terms are just independent scaling factors.

So far we have not gained anything computationally. However, we will show in Section 17.2.9.3
that the Gaussian kernel can be written as the expectation of a set of random features:


$$\mathcal{K}_{\text{gauss}}(x, y) = \mathbb{E}\left[\eta(x)^{\mathsf{T}} \eta(y)\right] \tag{15.65}$$


where $\eta(x) \in \mathbb{R}^M$ is a random feature vector derived from $x$, either based on trigonometric functions
(Equation (17.60)) or exponential functions (Equation (17.61)). (The latter has the advantage that all
the features are positive, which gives much better results [Cho+20b].) Therefore for the regular
softmax attention, $A_{i,j}$ can be rewritten as


$$A_{i,j} = \mathbb{E}\left[\phi(q_i)^{\mathsf{T}} \phi(k_j)\right] \tag{15.66}$$


where $\phi$ is defined as:


$$\phi(x) = \exp\left(\frac{\|x\|_2^2}{2\sqrt{D}}\right) \eta\left(\frac{x}{D^{1/4}}\right) \tag{15.67}$$


We can write the full attention matrix as follows


$$A = \mathbb{E}\left[Q' (K')^{\mathsf{T}}\right] \tag{15.68}$$

where $Q', K' \in \mathbb{R}^{N \times M}$ have rows encoding random feature maps corresponding to the queries and
keys. (Note that we can get better performance if we ensure these random features are orthogonal,
see [Cho+20a] for the details.) See Figure 15.30 for an illustration.




Figure 15.31: Decomposition of the attention matrix $A$ can be leveraged to improve attention computations
via the matrix associativity property. To compute $AV$, we first calculate $G = (K')^{\mathsf{T}} V$ and then $Q'G$, resulting in
linear in $N$ space and time complexity. Used with kind permission of Krzysztof Choromanski.


We can create an approximation to $A$ by using a single sample of the random features $\phi(q_i)$ and
$\phi(k_j)$, and using a small value of $M$, say $M = O(D \log(D))$. We can then approximate the entire
attention operator in $O(N)$ time using


$$\text{attention}(Q, K, V) = \text{diag}^{-1}\left(Q'((K')^{\mathsf{T}} 1_N)\right) \left(Q'((K')^{\mathsf{T}} V)\right) \tag{15.69}$$


This can be shown to be an unbiased approximation to the exact softmax attention operator. See 
Figure 15.31 for an illustration. (For details on how to generalize this to masked (causal) attention, 
see [Cho+20a].) 
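The following plain-Python sketch illustrates the idea of linear-time attention with positive random features. It is written under several assumptions of our own: the feature map is the exponential (positive) random feature construction expressed directly for the softmax kernel, the random projections are i.i.d. Gaussian rather than orthogonal, and the sizes are tiny so that the result can be compared against exact softmax attention. The function names are illustrative.

```python
import math
import random

def positive_features(x, W):
    """Random features phi(x) with E[phi(q) . phi(k)] = exp(q . k / sqrt(D)).

    W is an M x D matrix whose rows are drawn i.i.d. from N(0, I).
    """
    D, M = len(x), len(W)
    z = [xi / D ** 0.25 for xi in x]
    half_norm = 0.5 * sum(zi * zi for zi in z)
    return [math.exp(sum(w * zi for w, zi in zip(row, z)) - half_norm) / math.sqrt(M)
            for row in W]

def linear_attention(Q, K, V, W):
    Qp = [positive_features(q, W) for q in Q]   # N x M
    Kp = [positive_features(k, W) for k in K]   # N x M
    M, dv = len(W), len(V[0])
    # G = (K')^T V and s = (K')^T 1 are computed once, in a single pass over the keys.
    G = [[sum(kp[m] * v[c] for kp, v in zip(Kp, V)) for c in range(dv)] for m in range(M)]
    s = [sum(kp[m] for kp in Kp) for m in range(M)]
    out = []
    for qp in Qp:                               # each query costs O(M * dv), independent of N
        norm = sum(qm * sm for qm, sm in zip(qp, s))
        out.append([sum(qp[m] * G[m][c] for m in range(M)) / norm for c in range(dv)])
    return out

def softmax_attention(Q, K, V):
    """Exact (quadratic-time) softmax attention, for comparison only."""
    D, dv = len(Q[0]), len(V[0])
    out = []
    for q in Q:
        scores = [math.exp(sum(a * b for a, b in zip(q, k)) / math.sqrt(D)) for k in K]
        total = sum(scores)
        out.append([sum(sc * v[c] for sc, v in zip(scores, V)) / total for c in range(dv)])
    return out

rng = random.Random(0)
N, D, M = 6, 4, 256
Q = [[rng.gauss(0, 0.5) for _ in range(D)] for _ in range(N)]
K = [[rng.gauss(0, 0.5) for _ in range(D)] for _ in range(N)]
V = [[rng.gauss(0, 1.0) for _ in range(D)] for _ in range(N)]
W = [[rng.gauss(0, 1.0) for _ in range(D)] for _ in range(M)]
approx = linear_attention(Q, K, V, W)
exact = softmax_attention(Q, K, V)
max_err = max(abs(a - e) for ra, re in zip(approx, exact) for a, e in zip(ra, re))
print("max abs difference from exact attention:", max_err)
```

Because all the features are positive, the implied attention weights are positive and sum to one, so each output row is a convex combination of the value vectors, just as in exact softmax attention.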


15.7 Language models and unsupervised representation learning 


We have discussed how RNNs and autoregressive (decoder-only) transformers can be used as language
models, which are generative sequence models of the form $p(x_1, \ldots, x_T) = \prod_{t=1}^{T} p(x_t | x_{1:t-1})$, where
each $x_t$ is a discrete token, such as a word or wordpiece. (See Section 1.5.4 for a discussion of
text preprocessing methods.) The latent state of these models can then be used as a continuous
vector representation of the text. That is, instead of using the one-hot vector $x_t$, or a learned
embedding of it (such as those discussed in Section 20.5), we use the hidden state $h_t$, which depends
on all the previous words in the sentence. These vectors can then be used as contextual word 
embeddings, for purposes such as text classification or seq2seq tasks (see e.g. [LKB20] for a review). 
The advantage of this approach is that we can pre-train the language model in an unsupervised 
way, on a large corpus of text, and then we can fine-tune the model in a supervised way on a small 
labeled task-specific dataset. (This general approach is called transfer learning, see Section 19.2 
for details.) 

If our primary goal is to compute useful representations for transfer learning, as opposed to 
generating text, we can replace the generative sequence model with non-causal models that can 
compute a representation of a sentence, but cannot generate it. These models have the advantage
that now the hidden state $h_t$ can depend on the past, $y_{1:t-1}$, present $y_t$, and future, $y_{t+1:T}$. This
can sometimes result in better representations, since it takes into account more context.

In the sections below, we briefly discuss some unsupervised models for representation learning on 
text, using both causal and non-causal models. 
Figure 15.32: Illustration of ELMo bidirectional language model. Here $y_t = x_{t+1}$ when acting as the target for
the forwards LSTM, and $y_t = x_{t-1}$ for the backwards LSTM. (We add bos and eos sentinels to handle the
edge cases.) From [Wen19]. Used with kind permission of Lilian Weng.


15.7.1 ELMo 


In [Pet+18], they present a method called ELMo, which is short for “Embeddings from Language 
Model”. The basic idea is to fit two RNN language models, one left-to-right, and one right-to-left, 
and then to combine their hidden state representations to come up with an embedding for each 
word. Unlike a biRNN (Section 15.2.2), which needs an input-output pair, ELMo is trained in an 
unsupervised way, to minimize the negative log likelihood of the input sentence $x_{1:T}$:


$$\mathcal{L}(\theta) = -\sum_{t=1}^{T} \left[\log p(x_t | x_{1:t-1}; \theta_e, \theta^{\rightarrow}, \theta_s) + \log p(x_t | x_{t+1:T}; \theta_e, \theta^{\leftarrow}, \theta_s)\right] \tag{15.70}$$


where $\theta_e$ are the shared parameters of the embedding layer, $\theta_s$ are the shared parameters of the
softmax output layer, and $\theta^{\rightarrow}$ and $\theta^{\leftarrow}$ are the parameters of the two RNN models. (They use LSTM
RNNs, described in Section 15.2.7.2.) See Figure 15.32 for an illustration.

After training, we define the contextual representation $r_t = [e_t, h^{\rightarrow}_{t,1:L}, h^{\leftarrow}_{t,1:L}]$, where $L$ is the
number of layers in the LSTM. We then learn a task-specific set of linear weights to map this to
the final context-specific embedding of each token: $r_t^j = r_t^{\mathsf{T}} w^j$, where $j$ is the task id. If we are
performing a syntactic task like part-of-speech (POS) tagging (i.e., labeling each word as a noun,
verb, adjective, etc.), then the task will learn to put more weight on lower layers. If we are performing
a semantic task like word sense disambiguation (WSD), then the task will learn to put more
weight on higher layers. In both cases, we only need a small amount of task-specific labeled data,
since we are just learning a single weight vector, to map from $r_{1:T}$ to the target labels $y_{1:T}$.
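One concrete way to realize such a task-specific weighting is to mix the per-layer vectors of a token with softmax-normalized scalar weights and a global scale, which is our reading of the layer-mixing scheme in [Pet+18] rather than a literal transcription of the notation above. The sketch below assumes all layers have the same dimensionality; the names `mix_layers`, `w`, and `gamma` are illustrative.

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]

def mix_layers(layers, w, gamma=1.0):
    """Combine per-layer vectors for one token using softmax-normalized weights.

    layers: list of equal-length vectors (embedding layer first, then LSTM layers).
    w:      one unnormalized scalar weight per layer, learned per task.
    gamma:  a task-specific global scale (assumed here to default to 1).
    """
    s = softmax(w)
    dim = len(layers[0])
    return [gamma * sum(sj * layer[c] for sj, layer in zip(s, layers)) for c in range(dim)]

layers = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]      # toy vectors for a single token
uniform = mix_layers(layers, w=[0.0, 0.0, 0.0])     # equal weight on every layer
low_heavy = mix_layers(layers, w=[10.0, 0.0, 0.0])  # almost all weight on the lowest layer
print(uniform, low_heavy)
```

A syntactic task would be expected to learn weights resembling the second setting, while a semantic task would shift the weight towards the later entries.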


15.7.2 BERT 


In this section, we describe the BERT model (Bidirectional Encoder Representations from Transformers)
of [Dev+19]. Like ELMo, this is a non-causal model that can be used to create representations
of text, but not to generate text. In particular, it uses a transformer model to map a modified version
Figure 15.33: Illustration of (a) BERT and (b) GPT. $E_t$ is the embedding vector for the input token at
location $t$, and $T_t$ is the output target to be predicted. From Figure 3 of [Dev+19]. Used with kind permission
of Ming-Wei Chang.


of a sequence back to the unmodified form. The modified input masks out the word at location $t$ and keeps
all the other words, and the task is to predict the missing word. This is called the fill-in-the-blank or
cloze task.


15.7.2.1 Masked language model task 


More precisely, the model is trained to minimize the negative log pseudo-likelihood:


$$\mathcal{L} = \mathbb{E}_{x \sim \mathcal{D}} \mathbb{E}_{m} \sum_{i \in m} -\log p(x_i | x_{-m}) \tag{15.71}$$


where $m$ is a random binary mask. For example, if we train the model on transcripts from cooking
videos, we might create a training sentence of the form 


Let’s make [MASK] chicken! [SEP] It [MASK] great with orange sauce. 


where [SEP] is a separator token inserted between two sentences. The desired target labels for the 
masked words are “some” and “tastes”. (This example is from [Sun+19a].)

The conditional probability is given by applying a softmax to the final layer hidden vector at
location $i$:


$p(x_i | \hat{x}) = \frac{\exp(h(\hat{x})_i^T e(x_i))}{\sum_{x'} \exp(h(\hat{x})_i^T e(x'))}$    (15.72)


where $\hat{x} = x_{-m}$ is the masked input sentence, and $e(x)$ is the embedding for token $x$. This is used
to compute the loss at the masked locations; this is therefore called a masked language model.
(This is similar to a denoising autoencoder, Section 20.3.2.) See Figure 15.33a for an illustration of
the model.
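
As a rough illustration of Equations (15.71) and (15.72), the following NumPy sketch computes the negative log pseudo-likelihood at the masked positions. It assumes the final-layer hidden vectors have already been computed by some encoder (random numbers stand in for them here), and the function name masked_lm_loss and all array shapes are hypothetical choices made for this example.

```python
import numpy as np

def masked_lm_loss(hidden, embed, tokens, mask):
    """Negative log pseudo-likelihood summed over the masked positions.

    hidden: (T, D) final-layer vectors h(x_hat)_i for the masked input
    embed:  (V, D) token embedding matrix, row v is e(v)
    tokens: (T,) true token ids x_i
    mask:   (T,) boolean, True where the token was masked out
    """
    logits = hidden @ embed.T                              # (T, V): h_i^T e(x') for every x'
    logits = logits - logits.max(axis=1, keepdims=True)    # for numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(tokens)), tokens]       # -log p(x_i | x_hat)
    return nll[mask].sum()

# Stand-in data (assumption: a real model would supply `hidden`).
rng = np.random.default_rng(0)
T, D, V = 6, 4, 10
hidden = rng.normal(size=(T, D))
embed = rng.normal(size=(V, D))
tokens = rng.integers(0, V, size=T)
mask = np.array([False, True, False, False, True, False])
loss = masked_lm_loss(hidden, embed, tokens, mask)
```

Note that the same embedding matrix is used both to score candidate tokens and (in a full model) to embed the input, which is what Equation (15.72) suggests; other parameterizations of the output layer are possible.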


15.7.2.2 Next sentence prediction task 


In addition to the masked language model objective, the original BERT paper added a second
objective, in which the model is trained to classify whether one sentence follows another. More precisely,
[Figure 15.34 content: the input “<cls> this movie is great <sep> i like it <sep>” is represented as the elementwise sum of token embeddings, segment embeddings (A or B), and positional embeddings.]


Figure 15.34: Illustration of how a pair of input sequences, denoted A and B, are encoded before feeding to 
BERT. From Figure 14.8.2 of [Zha+20]. Used with kind permission of Aston Zhang. 


the model is fed as input


CLS $A_1$ $A_2$ ... $A_m$ SEP $B_1$ $B_2$ ... $B_n$ SEP    (15.73)

where SEP is a special separator token, and CLS is a special token marking the class. If sentence B 
follows A in the original text, we set the target label to y = 1, but if B is a randomly chosen sentence, 
we set the target label to y = 0. This is called the next sentence prediction task. This kind of 
pre-training can be useful for sentence-pair classification tasks, such as textual entailment or textual 
similarity, which we discussed in Section 15.4.6. (Note that this kind of pre-training is considered 
unsupervised, or self-supervised, since the target labels are automatically generated.) 

When performing next sentence prediction, the input to the model is specified using 3 different 
embeddings: one per token, one for each segment label (sentence A or B), and one per location 
(using a learned positional embedding). These are then added. See Figure 15.34 for an illustration. 
BERT then uses a transformer encoder to learn a mapping from this input embedding sequence to 
an output embedding sequence, which gets decoded into word labels (for the masked locations) or a 
class label (for the CLS location). 
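
The way the three embeddings are combined can be sketched as follows. This is a minimal illustration with randomly initialized tables rather than learned ones; the vocabulary, the dimensions, and the function name encode_pair are assumptions made for the example. It also assumes (following common practice) that the CLS token and the first SEP token are assigned to segment A, and the final SEP token to segment B.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = ["<cls>", "<sep>", "this", "movie", "is", "great", "i", "like", "it"]
tok2id = {w: i for i, w in enumerate(vocab)}
D, max_len = 8, 16

# Random stand-ins for the three learned embedding tables.
token_emb = rng.normal(size=(len(vocab), D))
segment_emb = rng.normal(size=(2, D))        # row 0 = sentence A, row 1 = sentence B
position_emb = rng.normal(size=(max_len, D))

def encode_pair(sent_a, sent_b):
    """Return the token list and the (T, D) summed input embeddings."""
    tokens = ["<cls>"] + sent_a + ["<sep>"] + sent_b + ["<sep>"]
    segments = [0] * (len(sent_a) + 2) + [1] * (len(sent_b) + 1)
    ids = np.array([tok2id[w] for w in tokens])
    seg = np.array(segments)
    pos = np.arange(len(tokens))
    return tokens, token_emb[ids] + segment_emb[seg] + position_emb[pos]

tokens, X = encode_pair("this movie is great".split(), "i like it".split())
```

The resulting matrix X has one row per input token and would be passed to the transformer encoder.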


15.7.2.3 Fine-tuning BERT for NLP applications 


After pre-training BERT in an unsupervised way, we can use it for various downtream tasks by 
performing supervised fine-tuning. (See Section 19.2 for more background on such transfer learning 
methods.) Figure 15.35 illustrates how we can modify a BERT model to perform different tasks, by 
simply adding one or more new output heads to the final hidden layer. See bert_jax.ipynb for some
sample code.

In Figure 15.35(a), we show how we can tackle single sentence classification (e.g., sentiment 
analysis): we simply take the feature vector associated with the dummy CLS token and feed it into an 
MLP. Since each output attends to all inputs, this hidden vector will summarize the entire sentence. 
The MLP then learns to map this to the desired label space. 

In Figure 15.35(b), we show how we can tackle sentence-pair classification (e.g., textual entail- 
ment, as discussed in Section 15.4.6): we just feed in the two input sentences, formatted as in 
Equation (15.73), and then classify the CLS token. 

In Figure 15.35(c), we show how we can tackle single sentence tagging, in which we associate a 


Figure 15.35: Illustration of how BERT can be used for different kinds of supervised NLP tasks. (a) Single
sentence classification (e.g., sentiment analysis); (b) Sentence-pair classification (e.g., textual entailment); 
(c) Single sentence tagging (e.g., shallow parsing); (d) Question answering. From Figure 4 of [Dev+19]. Used 
with kind permission of Ming-Wei Chang. 


label or tag with each word, instead of just the entire sentence. A common application of this is part
of speech tagging, in which we annotate each word as a noun, verb, adjective, etc. Another application
of this is noun phrase chunking, also called shallow parsing, in which we must annotate the
span of each noun phrase. The span is encoded using the BIO notation, in which B is the beginning
of an entity, I is for inside, and O is for outside any entity. For example, consider the following
sentence:


B       I       O    O     O          B   I          O    B   I   I
British Airways rose after announcing its withdrawal from the UAI deal


We see that there are 3 noun phrases, “British Airways”, “its withdrawal” and “the UAI deal”. (We
require that the B, I and O labels occur in order, so this is a prior constraint that can be included in
the model.)
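
To make the BIO encoding concrete, here is a small sketch that decodes a tag sequence back into noun phrase spans. The function name bio_to_spans is hypothetical, and the ordering constraint is enforced here simply by raising an error when an I tag does not follow a B or I tag; a trained tagger might instead build this constraint into the model.

```python
def bio_to_spans(tags):
    """Convert a list of B/I/O tags into (start, end) spans, end exclusive."""
    spans, start = [], None
    for i, tag in enumerate(tags):
        if tag == "B":
            if start is not None:          # a new phrase starts right after another
                spans.append((start, i))
            start = i
        elif tag == "O":
            if start is not None:          # close the phrase that was open
                spans.append((start, i))
            start = None
        elif tag == "I" and start is None:
            raise ValueError(f"I tag at position {i} does not follow B or I")
    if start is not None:                  # phrase running to the end of the sentence
        spans.append((start, len(tags)))
    return spans

words = "British Airways rose after announcing its withdrawal from the UAI deal".split()
tags = "B I O O O B I O B I I".split()
phrases = [" ".join(words[s:e]) for s, e in bio_to_spans(tags)]
```

Here phrases recovers the three noun phrases of the example sentence.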

We can also associate types with each noun phrase, for example distinguishing person, location, 
organization, and other. Thus the label space becomes {B-Per, I-Per, B-Loc, I-Loc, B-Org, I-Org, 
Outside }. This is called named entity recognition, and is a key step in information extraction. 
For example, consider the following sentence: 


BP  IP    O     O     O  BL  IL    BP    O      O   O       O
Mrs Green spoke today in New York. Green chairs the finance committee.


From this, we infer that the first sentence has two named entities, namely “Mrs Green” (of type 
Person) and “New York” (of type Location). The second sentence mentions another person, “Green”, 
that most likely is the same as the first person, although this across-sentence entity resolution is not 
part of the basic NER task. 

Finally, in Figure 15.35(d), we show how we can tackle question answering. Here the first input
sentence is the question, the second is the background text, and the output is required to specify
the start and end locations of the relevant part of the background that contains the answer (see
Table 1.4). The start location $s$ and end location $e$ are computed by applying 2 different MLPs to a
pooled version of the output encodings for the background text; the output of the MLPs is a softmax
over all locations. At test time, we can extract the span $(i, j)$ which maximizes the sum of scores
$s_i + e_j$ for $i \le j$.
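
The test-time span extraction step can be sketched as below. The score vectors are made-up numbers standing in for the outputs of the two MLPs, and best_span is a hypothetical helper name; the sketch assumes the end location may coincide with the start location.

```python
import numpy as np

def best_span(start_scores, end_scores):
    """Return (i, j) maximizing start_scores[i] + end_scores[j] subject to i <= j."""
    T = len(start_scores)
    total = start_scores[:, None] + end_scores[None, :]     # total[i, j] = s_i + e_j
    valid = np.triu(np.ones((T, T), dtype=bool))            # True where i <= j
    total = np.where(valid, total, -np.inf)                 # rule out spans with i > j
    i, j = np.unravel_index(np.argmax(total), total.shape)
    return int(i), int(j)

s = np.array([0.1, 2.0, 0.3, 0.2])   # hypothetical start scores
e = np.array([3.0, 0.1, 1.5, 0.4])   # hypothetical end scores
span = best_span(s, e)
```

In this example the largest start score is at location 1 and the largest end score is at location 0, which is not a valid span, so the constraint changes the answer. This brute-force search costs O(T^2); in practice one often also limits the maximum span length.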

BERT achieves state-of-the-art performance on many NLP tasks. Interestingly, [TDP19] shows
that BERT implicitly rediscovers the standard NLP pipeline, in which different layers perform tasks
such as part of speech (POS) tagging, parsing, named entity recognition (NER), semantic
role labeling (SRL), coreference resolution, etc. More details on NLP can be found in [JM20].


15.7.3 GPT 


In [Rad+18], they propose a model called GPT, which is short for “Generative Pre-training Trans- 
former”. This is a causal (generative) model, that uses a masked transformer as the decoder. See 
Figure 15.33b for an illustration. 

In the original GPT paper, they jointly optimize on a large unlabeled dataset, and a small labeled
dataset. In the classification setting, the loss is given by $\mathcal{L} = \mathcal{L}_{\text{cls}} + \lambda \mathcal{L}_{\text{LM}}$, where
$\mathcal{L}_{\text{cls}} = -\sum_{(x,y) \in \mathcal{D}_L} \log p(y|x)$ is the classification loss on the labeled data, and
$\mathcal{L}_{\text{LM}} = -\sum_{x \in \mathcal{D}_U} \sum_t \log p(x_t|x_{1:t-1})$ is the language modeling loss on the unlabeled data.

In [Rad+19], they propose GPT-2, which is a larger version of GPT, trained on a large web 
corpus called WebText. They also eliminate any task-specific training, and instead just train it 
as a language model. The GPT-3 [Bro+20] model is an even larger version of GPT-2, but based 
on the same principles. More recently, OpenAI released ChatGPT [Ope], which is an improved 
version of GPT-3 which has been trained to have interactive dialogs by using a technique called 
reinforcement learning from human feedback or RLHF, a technique first introduced in the 
InstructGPT paper [Ouy+22]. This uses reinforcement learning techniques to fine tune the model 
so that it generates responses that are more “aligned” with human intent, as estimated by a ranking 
model, which is pre-trained on supervised data. 


15.7.3.1 Applications of GPT 


GPT can generate text given an initial input prompt. The prompt can specify a task; if the
generated response fulfills the task “out of the box”, we say the model is performing zero-shot task
transfer (see Section 19.6 for details). For example, to perform abstractive summarization of
some input text $x_{1:T}$ (as opposed to extractive summarization, which just selects a subset of the
input words), we sample from $p(x_{T+1:T+100} | [x_{1:T}; \text{TL;DR}])$, where TL;DR is a special token added
to the end of the input text, which tells the system the user wants a summary. TL;DR stands for
“too long; didn’t read” and frequently occurs in webtext followed by a human-created summary. By
Figure 15.36: Illustration of how the T5 model (“Text-to-text Transfer Transformer”) can be used to perform
multiple NLP tasks, such as translating English to German; determining if a sentence is linguistically valid or
not (CoLA stands for “Corpus of Linguistic Acceptability”); determining the degree of semantic similarity
(STSB stands for “Semantic Textual Similarity Benchmark”); and abstractive summarization. From Figure 1
of [Raf+20]. Used with kind permission of Colin Raffel.

[Figure content, input -> output:
"translate English to German: That is good." -> "Das ist gut."
"cola sentence: The course is jumping well." -> "not acceptable"
"stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field." -> "3.8"
"summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi..." -> "six people hospitalized after a storm in attala county."]


adding this token to the input, the user hopes to “trigger” the transformer decoder into a state in 
which it enters summarization mode. (This is an example of “prompt engineering”.) However, an 
arguably better way to tell the model what task to perform is to train it on input-output pairs, as 
discussed in Section 15.7.4. 

GPT can also be used to create chatbots, such as ChatGPT [Ope], and for code generation 
(see e.g., [HBK23]). 


15.7.4 T5 


Many models are trained in an unsupervised way, and then fine-tuned on specific tasks. It is also 
possible to train a single model to perform multiple tasks, by telling the system what task to perform 
as part of the input sentence, and then training it as a seq2seq model, as illustrated in Figure 15.36. 
This is the approach used in T5 [Raf+20], which stands for “Text-to-text Transfer Transformer”. The 
model is a standard seq2seq transformer, that is pretrained on unsupervised $(x', x'')$ pairs, where $x'$
is a masked version of $x$ and $x''$ are the missing tokens that need to be predicted, and then fine-tuned
on multiple supervised $(x, y)$ pairs.

The unsupervised data comes from C4, or the “Colossal Clean Crawled Corpus”, a 750GB corpus 
of web text. This is used for pretraining using a BERT-like denoising objective. For example, the 
sentence $x$ = “Thank you for inviting me to your party last week” may get converted to the input
$x'$ = “Thank you <X> me to your party <Y> week” and the output (target) $x''$ = “<X> for inviting
<Y> last <EOS>”, where <X> and <Y> are tokens that are unique to this example.
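
The construction of the denoising pair in this example can be sketched as follows. The function name span_corrupt and the choice of which spans to drop are assumptions for illustration; in actual pretraining the spans would be sampled at random, and the sentinel tokens would come from the model's vocabulary.

```python
def span_corrupt(tokens, spans, sentinels=("<X>", "<Y>", "<Z>")):
    """Replace each dropped span with a sentinel and build the matching target.

    spans: sorted, non-overlapping (start, end) index pairs, end exclusive.
    Assumes there are at least as many sentinels as spans.
    """
    inp, tgt, prev = [], [], 0
    for sentinel, (start, end) in zip(sentinels, spans):
        inp += tokens[prev:start] + [sentinel]     # keep text before the span, then mark the gap
        tgt += [sentinel] + tokens[start:end]      # target lists the tokens that were removed
        prev = end
    inp += tokens[prev:]                           # keep the remaining text
    tgt.append("<EOS>")
    return " ".join(inp), " ".join(tgt)

x = "Thank you for inviting me to your party last week".split()
x_in, x_out = span_corrupt(x, [(2, 4), (8, 9)])
```

With these spans, x_in and x_out reproduce the input and target strings given in the text.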

The supervised datasets are manually created, and are taken from the literature. Recently the 
FLAN-T5 model [Chu+22] was released, which uses instruction fine-tuning on over 1800 such 
tasks, including language translation, text classification, and question answering. The resulting model 
is currently the state-of-the-art on many NLP tasks. 


15.7.5 Discussion 


Large language models or LLMs, such as BERT and GPT-3, have recently generated a lot of
interest, and have even made their way into the mainstream media.⁷ However, there is some doubt
about whether such systems “understand” language in any meaningful way, beyond just rearranging 
word patterns seen in their massive training sets. For example, [NK19] show that the ability of BERT 
to perform almost as well as humans on the Argument Reasoning Comprehension Task is “entirely 
accounted for by exploitation of spurious statistical cues in the dataset”. By slightly tweaking the 
dataset, performance can be reduced to chance levels. For other criticisms of such models, see e.g., 
[BK20; Mar20; Dzi+23; Mah+23]. 


7. See e.g., https://www.nytimes.com/2020/11/24/science/artificial-intelligence-ai-gpt3.html.
Nonparametric Models


16 Exemplar-based Methods


So far in this book, we have mostly focused on parametric models, either unconditional $p(y|\theta)$
or conditional $p(y|x, \theta)$, where $\theta$ is a fixed-dimensional vector of parameters. The parameters are
estimated from a variable-sized dataset, $\mathcal{D} = \{(x_n, y_n) : n = 1:N\}$, but after model fitting, the data
is thrown away.

In this section we consider various kinds of nonparametric models, that keep the training data 
around. Thus the effective number of parameters of the model can grow with $|\mathcal{D}|$. We focus on
models that can be defined in terms of the similarity between a test input, $x$, and each of the
training inputs, $x_n$. Alternatively, we can define the models in terms of a dissimilarity or distance
function $d(x, x_n)$. Since the models keep the training examples around at test time, we call them
exemplar-based models. (This approach is also called instance-based learning [AKA91], or 
memory-based learning.) 


16.1 K nearest neighbor (KNN) classification 


In this section, we discuss one of the simplest kinds of classifier, known as the K nearest neighbor
(KNN) classifier. The idea is as follows: to classify a new input $x$, we find the $K$ closest examples
to $x$ in the training set, denoted $N_K(x, \mathcal{D})$, and then look at their labels, to derive a distribution
over the outputs for the local region around $x$. More precisely, we compute


$p(y = c | x, \mathcal{D}) = \frac{1}{K} \sum_{n \in N_K(x, \mathcal{D})} \mathbb{I}(y_n = c)$    (16.1)


We can then return this distribution, or the majority label.
The two main parameters in the model are the size of the neighborhood, $K$, and the distance
metric $d(x, x')$. For the latter, it is common to use the Mahalanobis distance


$d_M(x, \mu) = \sqrt{(x - \mu)^T M (x - \mu)}$    (16.2)


where M is a positive definite matrix. If M = I, this reduces to Euclidean distance. We discuss how 
to learn the distance metric in Section 16.2. 
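
A minimal NumPy sketch of Equations (16.1) and (16.2) is given below. It uses brute-force search over all training points, and the function name knn_predict_proba and the toy data are assumptions for illustration. Since the square root is monotonic, ranking by the squared Mahalanobis distance gives the same neighbors.

```python
import numpy as np

def knn_predict_proba(X_train, y_train, x, K, n_classes, M=None):
    """Empirical label distribution among the K nearest neighbors of x."""
    if M is None:
        M = np.eye(X_train.shape[1])                 # M = I gives Euclidean distance
    diff = X_train - x
    d2 = np.einsum("nd,de,ne->n", diff, M, diff)     # squared Mahalanobis distances
    nbrs = np.argsort(d2)[:K]                        # indices of the K closest points
    counts = np.bincount(y_train[nbrs], minlength=n_classes)
    return counts / K

X_train = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
                    [1.0, 1.0], [1.1, 1.0], [3.0, 3.0]])
y_train = np.array([1, 1, 1, 0, 0, 0])
p = knn_predict_proba(X_train, y_train, np.array([0.4, 0.4]), K=5, n_classes=2)
```

For this toy dataset, 3 of the 5 nearest neighbors have label 1, so the predicted probability of class 1 is 3/5.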

Despite the simplicity of KNN classifiers, it can be shown that this approach gets within a factor 
of 2 of the Bayes error (which measures the performance of the best possible classifier) as $N \to \infty$
[CH67; CD14]. (Of course the convergence rate to this optimal performance may be poor in practice, 
for reasons we discuss in Section 16.1.2.) 


Figure 16.1: (a) Illustration of a K-nearest neighbors classifier in 2d for K = 5. The nearest neighbors 
of test point x have labels {1,1,1,0,0}, so we predict p(y = 1|x, D) = 3/5. (b) Illustration of the Voronoi 
tessellation induced by 1-NN. Adapted from Figure 4.18 of [DHS01]. Generated by knn_voronoi_plot.ipynb. 


16.1.1 Example 
We illustrate the KNN classifier in 2d in Figure 16.1(a) for $K = 5$. The test point is marked as
an “x”. 3 of the 5 nearest neighbors have label 1, and 2 of the 5 have label 0. Hence we predict
$p(y = 1 | x, \mathcal{D}) = 3/5 = 0.6$.

If we use K = 1, we just return the label of the nearest neighbor, so the predictive distribution 
becomes a delta function. A KNN classifier with K = 1 induces a Voronoi tessellation of the 
points (see Figure 16.1(b)). This is a partition of space which associates a region $V(x_n)$ with each
point $x_n$ in such a way that all points in $V(x_n)$ are closer to $x_n$ than to any other point. Within
each cell, the predicted label is the label of the corresponding training point. Thus the training error 
will be 0 when K = 1. However, such a model is usually overfitting the training set, as we show 
below. 

Figure 16.2 gives an example of KNN applied to a 2d dataset, in which we have three classes. We 
see how, with K = 1, the method makes zero errors on the training set. As K increases, the decision 
boundaries become smoother (since we are averaging over larger neighborhoods), so the training 
error increases, as we start to underfit. This is shown in Figure 16.2(d). The test error shows the 
usual U-shaped curve. 


16.1.2 The curse of dimensionality 


The main statistical problem with KNN classifiers is that they do not work well with high dimensional 
inputs, due to the curse of dimensionality. 

The basic problem is that the volume of space grows exponentially fast with dimension, so you 
might have to look quite far away in space to find your nearest neighbor. To make this more precise, 
consider this example from [HTF09, p22]. Suppose we apply a KNN classifier to data where the 
inputs are uniformly distributed in the D-dimensional unit cube. Suppose we estimate the density of 
class labels around a test point $x$ by “growing” a hyper-cube around $x$ until it contains a desired
fraction $p$ of the data points. The expected edge length of this cube will be $e_D(p) \triangleq p^{1/D}$; this
function is plotted in Figure 16.3(b). If $D = 10$, and we want to base our estimate on 10% of the
Figure 16.2: Decision boundaries induced by a KNN classifier. (a) K = 1. [...] (d) Train
and test error vs K. Generated by knn_classify_demo.ipynb.

Figure 16.3: Illustration of the curse of dimensionality. (a) We embed a small cube of side s inside
a larger unit cube. (b) We plot the edge length of a cube needed to cover a given volume of the unit
cube as a function of the number of dimensions. Adapted from Figure 2.6 from [HTF09]. Generated by
curse_dimensionality_plot.ipynb.

data, we have $e_{10}(0.1) = 0.8$, so we need to extend the cube 80% along each dimension around $x$.
Even if we only use 1% of the data, we find $e_{10}(0.01) = 0.63$. Since the range of the data is only 0 to
1 along each dimension, we see that the method is no longer very local, despite the name “nearest 
neighbor”. The trouble with looking at neighbors that are so far away is that they may not be good 
predictors about the behavior of the function at a given point. 
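
The numbers quoted above follow directly from the formula for the expected edge length, as the short calculation below shows. The function name edge_length is just a label chosen for this sketch.

```python
def edge_length(p, D):
    """Edge length of a sub-cube covering a fraction p of the D-dim unit cube."""
    return p ** (1.0 / D)

e10_10pct = edge_length(0.1, 10)    # roughly 0.79, i.e. about 80% of each axis
e10_1pct = edge_length(0.01, 10)    # roughly 0.63
e2_10pct = edge_length(0.1, 2)      # roughly 0.32, for comparison in 2d
```

The comparison with the 2d case illustrates how quickly the neighborhood stops being local as the dimension grows.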

There are two main solutions to the curse: make some assumptions about the form of the function 
(i.e., use a parametric model), and/or use a metric that only cares about a subset of the dimensions 
(see Section 16.2). 


16.1.3 Reducing the speed and memory requirements 


KNN classifiers store all the training data. This is obviously very wasteful of space. Various heuristic 
pruning techniques have been proposed to remove points that do not affect the decision boundaries, 
see e.g., [WM00]. In Section 17.4, we discuss a more principled approach based on a sparsity-promoting
prior; the resulting method is called a sparse kernel machine, and only keeps a subset of
the most useful exemplars. 

In terms of running time, the challenge is to find the K nearest neighbors in less than O(N) 
time, where N is the size of the training set. Finding exact nearest neighbors is computationally 
intractable when the dimensionality of the space goes above about 10 dimensions, so most methods 
focus on finding the approximate nearest neighbors. There are two main classes of techniques, based 
on partitioning space into regions, or using hashing. 

For partitioning methods, one can either use some kind of k-d tree, which divides space into 
axis-parallel regions, or some kind of clustering method, which uses anchor points. For hashing 
methods, locality sensitive hashing (LSH) [GIM99] is widely used, although more recent methods 
learn the hashing function from data (see e.g., [Wan+15]). See [LRU14] for a good introduction to
hashing methods. 

An open-source library called FAISS, for efficient exact and approximate nearest neighbor search
(and K-means clustering) of dense vectors, is available at https://github.com/facebookresearch/faiss,
and described in [JDJ17].


16.1.4 Open set recognition 
Ask not what this is called, ask what this is like. — Moshe Bar.[Bar09] 


In all of the classification problems we have considered so far, we have assumed that the set of 
classes C is fixed. (This is an example of the closed world assumption, which assumes there is a 
fixed number of (types of) things.) However, many real world problems involve test samples that 
come from new categories. This is called open set recognition, as we discuss below. 


16.1.4.1 Online learning, OOD detection and open set recognition 


For example, suppose we train a face recognition system to predict the identity of a person from a
fixed set or gallery of face images. Let $\mathcal{D}_t = \{(x_n, y_n) : x_n \in \mathcal{X}, y_n \in \mathcal{C}_t, n = 1:N_t\}$ be the labeled
dataset at time $t$, where $\mathcal{X}$ is the set of (face) images, and $\mathcal{C}_t = \{1, \ldots, C_t\}$ is the set of people known
to the system at time $t$ (where $C_t \le t$). At test time, the system may encounter a new person that
it has not seen before. Let $x_{t+1}$ be this new image, and $y_{t+1} = C_{t+1}$ be its new label. The system
needs to recognize that the input is from a new category, and not accidentally classify it with a
label from $\mathcal{C}_t$. This is called novelty detection. In this case, the input is being generated from the
distribution $p(x | y = C_{t+1})$, where $C_{t+1} \notin \mathcal{C}_t$ is the new “class label”. Detecting that $x_{t+1}$ is from a
novel class may be hard if the appearance of this new image is similar to the appearance of any of
the existing images in $\mathcal{D}_t$.

If the system is successful at detecting that $x_{t+1}$ is novel, then it may ask for the id of this new
instance, call it $C_{t+1}$. It can then add the labeled pair $(x_{t+1}, C_{t+1})$ to the dataset to create $\mathcal{D}_{t+1}$, and
can grow the set of unique classes by adding $C_{t+1}$ to $\mathcal{C}_t$ (c.f., [JK13]). This is called incremental
learning, online learning, life-long learning, or continual learning. At future time points,
the system may encounter an image sampled from $p(x | y = c)$, where $c$ is an existing class, or where $c$
is a new class, or the image may be sampled from some entirely different kind of distribution $p'(x)$
unrelated to faces (e.g., someone uploads a photo of their dog). (Detecting this latter kind of event is
called out-of-distribution or OOD detection.)

In this online setting, we often only get a few (sometimes just one) examples of each class. Prediction
in this setting is known as few-shot classification, and is discussed in more detail in Section 19.6.
KNN classifiers are well-suited to this task. For example, we can just store all the instances of each
class in a gallery of examples, as we explained above. At time $t + 1$, when we get input $x_{t+1}$, rather
than predicting a label for $x_{t+1}$ by comparing it to some parametric model for each class, we just
find the example in the gallery that is nearest (most similar) to $x_{t+1}$, call it $x'$. We then need to
determine if $x'$ and $x_{t+1}$ are sufficiently similar to constitute a match. (In the context of person
classification, this is known as person re-identification or face verification, see e.g., [WSH16].)
If there is no match, we can declare the input to be novel or OOD.
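
The gallery-based procedure described above can be sketched as follows. This is only an illustration under simplifying assumptions: inputs are already feature vectors, similarity is plain Euclidean distance, and a match is declared using a fixed distance threshold chosen by hand. The class name Gallery and the labels are hypothetical.

```python
import numpy as np

class Gallery:
    """Nearest-exemplar classifier that can reject inputs as novel."""

    def __init__(self, threshold):
        self.threshold = threshold      # maximum distance that counts as a match
        self.X, self.y = [], []

    def add(self, x, label):
        self.X.append(np.asarray(x, dtype=float))
        self.y.append(label)

    def classify(self, x):
        """Return the label of the nearest exemplar, or None if there is no match."""
        if not self.X:
            return None
        d = np.linalg.norm(np.stack(self.X) - np.asarray(x, dtype=float), axis=1)
        n = int(np.argmin(d))
        return self.y[n] if d[n] <= self.threshold else None

gallery = Gallery(threshold=0.5)
gallery.add([0.0, 0.0], "alice")
gallery.add([2.0, 2.0], "bob")
known = gallery.classify([0.1, -0.1])    # close to an existing exemplar
novel = gallery.classify([5.0, 5.0])     # far from every exemplar
if novel is None:
    gallery.add([5.0, 5.0], "carol")     # new label, e.g. obtained by asking a user
```

Adding the new labeled exemplar to the gallery corresponds to the incremental learning step: no model needs to be refit. How well this works depends heavily on the distance function and on the threshold.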

The key ingredient for all of the above problems is the (dis)similarity metric between inputs. We 
discuss ways to learn this in Section 16.2. 


16.1.4.2 Other open world problems 


The problems of open-set recognition and incremental learning are just examples of problems that
require the open world assumption (c.f., [Rus15]). There are many other examples of such problems.

For example, consider the problem of entity resolution, also called entity linking. In this problem,
we need to determine if different strings (e.g., “John Smith” and “Jon Smith”) refer to the same entity
or not. See e.g., [SHF15] for details.

Another important application is in multi-object tracking. For example, when a radar system 
detects a new “blip”, is it due to an existing missile that is being tracked, or is it a new object
that has entered the airspace? An elegant mathematical framework for dealing with such problems, 
known as random finite sets, is described in [Mah07; Mah13; Vo+15]. 


16.2 Learning distance metrics 


Being able to compute the “semantic distance” between a pair of points, $d(x, x') \in \mathbb{R}^+$ for $x, x' \in \mathcal{X}$,
or equivalently their similarity $s(x, x') \in \mathbb{R}^+$, is of crucial importance to tasks such as nearest
neighbor classification (Section 16.1), self-supervised learning (Section 19.2.4.4), similarity-based 
clustering (Section 21.5), content-based retrieval, visual tracking, etc. 
When the input space is $\mathcal{X} = \mathbb{R}^D$, the most common distance metric is the Mahalanobis distance


$d_M(x, x') = \sqrt{(x - x')^T M (x - x')}$    (16.3)


We discuss some methods to learn the matrix M in Section 16.2.1. For high dimensional inputs, or 
structured inputs, it is better to first learn an embedding e = f(x), and then to compute distances 
in embedding space. When f is a DNN, this is called deep metric learning; we discuss this in 
Section 16.2.2. 


16.2.1 Linear and convex methods 


In this section, we discuss some methods that try to learn the Mahalanobis distance matrix M, either 
directly (as a convex problem), or indirectly via a linear projection. For other approaches to metric
learning, see e.g., [Kul13; Kim19].


16.2.1.1 Large margin nearest neighbors 


In [WS09], they propose to learn the Mahalanobis matrix M so that the resulting distance metric 
works well when used by a nearest neighbor classifier. The resulting method is called large margin 
nearest neighbor or LMNN. 

This works as follows. For each example data point $i$, let $N_i$ be a set of target neighbors; these
are usually chosen to be the set of $K$ points with the same class label that are closest in Euclidean
distance. We now optimize $M$ so that we minimize the distance between each point $i$ and all of its
target neighbors $j \in N_i$:

$\mathcal{L}_{\text{pull}}(M) = \sum_{i=1}^N \sum_{j \in N_i} d_M(x_i, x_j)^2$    (16.4)


We also want to ensure that examples with incorrect labels are far away. To do this, we ensure that
each example $i$ is closer (by some margin $m > 0$) to its target neighbors $j$ than to other points $l$ with
different labels (so-called impostors). We can do this by minimizing


$\mathcal{L}_{\text{push}}(M) = \sum_{i=1}^N \sum_{j \in N_i} \sum_{l=1}^N \mathbb{I}(y_i \ne y_l) [m + d_M(x_i, x_j)^2 - d_M(x_i, x_l)^2]_+$    (16.5)


where $[z]_+ = \max(z, 0)$ is the hinge loss function (Section 4.3.2). The overall objective is $\mathcal{L}(M) =
(1 - \lambda) \mathcal{L}_{\text{pull}}(M) + \lambda \mathcal{L}_{\text{push}}(M)$, where $0 < \lambda < 1$. This is a convex function defined over a convex set,
which can be minimized using semidefinite programming. Alternatively, we can parameterize
the problem using $M = W^T W$, and then minimize wrt $W$ using unconstrained gradient methods.
This is no longer convex, but allows us to use a low-dimensional mapping $W$.

For large datasets, we need to tackle the $O(N^3)$ cost of computing Equation (16.5). We discuss
some speedup tricks in Section 16.2.5.
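
To make the LMNN objective concrete, the following sketch evaluates the pull and push terms for a given matrix M by direct summation. It does not perform the optimization (which would require a semidefinite programming solver or gradient descent on W); the function name lmnn_loss, the toy data, and the hand-picked target neighbor lists are assumptions for illustration.

```python
import numpy as np

def lmnn_loss(X, y, M, targets, margin=1.0, lam=0.5):
    """Evaluate (1 - lam) * L_pull + lam * L_push.

    targets[i] is the list of target-neighbor indices N_i for point i.
    """
    diff = X[:, None, :] - X[None, :, :]
    d2 = np.einsum("ijd,de,ije->ij", diff, M, diff)     # d2[i, j] = d_M(x_i, x_j)^2
    pull, push = 0.0, 0.0
    for i, nbrs in enumerate(targets):
        impostors = y != y[i]                           # indicator I(y_i != y_l) over l
        for j in nbrs:
            pull += d2[i, j]
            hinge = np.maximum(0.0, margin + d2[i, j] - d2[i, impostors])
            push += hinge.sum()
    return (1 - lam) * pull + lam * push

# Two classes separated along the second axis; the first axis is irrelevant.
X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 3.0], [1.0, 3.0]])
y = np.array([0, 0, 1, 1])
targets = [[1], [0], [3], [2]]
loss_euclid = lmnn_loss(X, y, np.eye(2), targets)
loss_scaled = lmnn_loss(X, y, np.diag([0.01, 1.0]), targets)
```

Down-weighting the irrelevant first dimension pulls target neighbors together without violating the margin, so loss_scaled is smaller than loss_euclid.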


16.2.1.2 Neighborhood components analysis 


Another way to learn a linear mapping $W$ such that $M = W^T W$ is known as neighborhood
components analysis or NCA [Gol+05]. This defines the probability that sample $x_i$ has $x_j$ as its
Figure 16.4: Illustration of latent coincidence analysis (LCA) as a directed graphical model. The inputs
$x, x' \in \mathbb{R}^D$ are mapped into Gaussian latent variables $z, z' \in \mathbb{R}^L$ via a linear mapping $W$. If the two latent
points coincide (within length scale $\kappa$) then we set the similarity label to $y = 1$, otherwise we set it to $y = 0$.
From Figure 1 of [DS12]. Used with kind permission of Lawrence Saul.


nearest neighbor using the linear softmax function


$p_{ij}^W = \frac{\exp(-\|W x_i - W x_j\|_2^2)}{\sum_{l \ne i} \exp(-\|W x_i - W x_l\|_2^2)}$    (16.6)


(This is a supervised version of stochastic neighborhood embeddings discussed in Section 20.4.10.1.)
The expected number of correctly classified examples for a 1NN classifier using distance
$W$ is given by $J(W) = \sum_{i=1}^N \sum_{j \ne i : y_j = y_i} p_{ij}^W$. Let $\mathcal{L}(W) = 1 - J(W)/N$ be the leave one out error.
We can minimize $\mathcal{L}$ wrt $W$ using gradient methods.
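
The NCA objective can be evaluated with a few lines of NumPy, as sketched below. This only computes the soft leave-one-out error for a given W; minimizing it would additionally require gradients (e.g., from an autodiff library). The function name nca_loo_error and the toy data are assumptions for illustration.

```python
import numpy as np

def nca_loo_error(W, X, y):
    """Soft leave-one-out error 1 - J(W)/N under the NCA neighbor distribution."""
    Z = X @ W.T                                              # rows are W x_i
    d2 = ((Z[:, None, :] - Z[None, :, :]) ** 2).sum(-1)      # ||W x_i - W x_j||^2
    np.fill_diagonal(d2, np.inf)                             # a point cannot pick itself
    logits = -d2
    logits -= logits.max(axis=1, keepdims=True)              # for numerical stability
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True)                        # P[i, j] = p_ij
    same = y[:, None] == y[None, :]
    J = (P * same).sum()                                     # diagonal of P is zero
    return 1.0 - J / len(y)

# The class depends only on the first coordinate; the second is a nuisance.
X = np.array([[0.0, 0.0], [0.0, 10.0], [1.0, 0.0], [1.0, 10.0]])
y = np.array([0, 0, 1, 1])
err_identity = nca_loo_error(np.eye(2), X, y)
err_projected = nca_loo_error(np.array([[5.0, 0.0]]), X, y)
```

With the identity mapping, each point's nearest neighbor belongs to the other class, so the error is close to 1; projecting onto the first coordinate (and rescaling it) brings the error close to 0.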


16.2.1.3 Latent coincidence analysis 


Yet another way to learn a linear mapping $W$ such that $M = W^T W$ is known as latent coincidence
analysis or LCA [DS12]. This defines a conditional latent variable model for mapping a pair of
inputs, $x$ and $x'$, to a label $y \in \{0, 1\}$, which specifies if the inputs are similar (e.g., have same class
label) or dissimilar. Each input $x \in \mathbb{R}^D$ is mapped to a low dimensional latent point $z \in \mathbb{R}^L$ using
a stochastic mapping $p(z|x) = \mathcal{N}(z | Wx, \sigma^2 I)$, and $p(z'|x') = \mathcal{N}(z' | Wx', \sigma^2 I)$. (Compare this to
factor analysis, discussed in Section 20.2.) We then define the probability that the two inputs are
similar using $p(y = 1 | z, z') = \exp(-\frac{1}{2\kappa^2} \|z - z'\|^2)$. See Figure 16.4 for an illustration of the modeling
assumptions.

We can maximize the log marginal likelihood $\ell(W, \sigma^2, \kappa^2) = \sum_n \log p(y_n | x_n, x'_n)$ using the EM
algorithm (Section 8.7.2). (We can set $\kappa = 1$ WLOG, since it just changes the scale of $W$.) More
precisely, in the E step, we compute the posterior $p(z, z' | x, x', y)$ (which can be done in closed
form), and in the M step, we solve a weighted least squares problem (c.f., Section 13.6.2). EM will 
monotonically increase the objective, and does not need step size adjustment, unlike the gradient based 
methods used in NCA (Section 16.2.1.2). (It is also possible to use variational Bayes (Section 4.6.8.3) 
to fit this model, as well as various sparse and nonlinear extensions, as discussed in [ZMY19].) 
16.2.2 Deep metric learning


When measuring the distance between high-dimensional or structured inputs, it is very useful to first 
learn an embedding to a lower dimensional “semantic” space, where distances are more meaningful, and 
less subject to the curse of dimensionality (Section 16.1.2). Let $e = f(x; \theta) \in \mathbb{R}^L$ be an embedding
of the input that preserves the “relevant” semantic aspects of the input, and let $\hat{e} = e / \|e\|_2$ be the
$\ell_2$-normalized version. This ensures that all points lie on a hyper-sphere. We can then measure the
distance between two points using the normalized Euclidean distance


$d(x_i, x_j; \theta) = \|\hat{e}_i - \hat{e}_j\|_2^2$    (16.7)

where smaller values mean more similar, or the cosine similarity

$d(x_i, x_j; \theta) = \hat{e}_i^T \hat{e}_j$    (16.8)


where larger values mean more similar. (Cosine similarity measures the cosine of the angle between the two
vectors, as illustrated in Figure 20.43.) These quantities are related via


$\|\hat{e}_i - \hat{e}_j\|_2^2 = (\hat{e}_i - \hat{e}_j)^T (\hat{e}_i - \hat{e}_j) = 2 - 2 \hat{e}_i^T \hat{e}_j$    (16.9)


This overall approach is called deep metric learning or DML. 
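
The relationship between the normalized squared Euclidean distance and the cosine similarity is easy to check numerically. In the sketch below, random vectors stand in for the outputs of an embedding network.

```python
import numpy as np

rng = np.random.default_rng(0)
e_i, e_j = rng.normal(size=5), rng.normal(size=5)   # stand-ins for f(x_i), f(x_j)

e_i_hat = e_i / np.linalg.norm(e_i)                 # l2-normalize onto the unit sphere
e_j_hat = e_j / np.linalg.norm(e_j)

sq_dist = np.sum((e_i_hat - e_j_hat) ** 2)          # normalized squared Euclidean distance
cos_sim = e_i_hat @ e_j_hat                         # cosine similarity
gap = sq_dist - (2 - 2 * cos_sim)                   # should vanish up to rounding error
```

Since one quantity is a decreasing affine function of the other, ranking neighbors by smallest distance or by largest cosine similarity gives the same ordering.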

The basic idea in DML is to learn the embedding function such that similar examples are closer than 
dissimilar examples. More precisely, we assume we have a labeled dataset, $\mathcal{D} = \{(x_i, y_i) : i = 1:N\}$,
from which we can derive a set of similar pairs, $\mathcal{S} = \{(i, j) : y_i = y_j\}$. If $(i, j) \in \mathcal{S}$ but $(i, k) \notin \mathcal{S}$,
then we assume that $x_i$ and $x_j$ should be close in embedding space, whereas $x_i$ and $x_k$ should be
far. We discuss various ways to enforce this property below. Note that these methods also work
when we do not have class labels, provided we have some other way of defining similar pairs. For 
example, in Section 19.2.4.3, we discuss self-supervised approaches to representation learning, that 
automatically create semantically similar pairs, and learn embeddings to force these pairs to be closer 
than unrelated pairs. 

Before discussing DML in more detail, it is worth mentioning that many recent approaches to 
DML are not as good as they claim to be, as pointed out in [MBL20; Rot+20]. (The claims in 
some of these papers are often invalid due to improper experimental comparisons, a common flaw in 
contemporary ML research, as discussed in e.g., [BLV19; LS19b].) We therefore focus on (slightly) 
older and simpler methods, that tend to be more robust. 


16.2.3 Classification losses 


Suppose we have labeled data with C classes. Then we can fit a classification model in O(NC) time, 
and then reuse the hidden features as an embedding function. (It is common to use the second-to-last 
layer, since it generalizes better to new classes than the final layer.) This approach is simple and 
scalable. However, it only learns to embed examples on the correct side of a decision boundary, which 
does not necessarily result in similar examples being placed close together and dissimilar examples 
being placed far apart. In addition, this method cannot be used if we do not have labeled training 
data.

Figure 16.5: Networks for deep metric learning. (a) Siamese network. (b) Triplet network. Adapted from
Figure 5 of [KB19].


16.2.4 Ranking losses 


In this section, we consider minimizing a ranking loss, to ensure that similar examples are closer
than dissimilar examples. Most of these methods do not need class labels (although we sometimes 
assume that labels exist as a notationally simple way to define similarity). 


16.2.4.1 Pairwise (contrastive) loss and Siamese networks 


One of the earliest approaches to representation learning from similar/dissimilar pairs was based on 
minimizing the following contrastive loss [CHL05]: 


$$\mathcal{L}(\boldsymbol{\theta}; \mathbf{x}_i, \mathbf{x}_j) = \mathbb{I}(y_i = y_j)\, d(\mathbf{x}_i, \mathbf{x}_j)^2 + \mathbb{I}(y_i \neq y_j)\, [m - d(\mathbf{x}_i, \mathbf{x}_j)]_+^2 \tag{16.10}$$


where $[z]_+ = \max(0, z)$ is the hinge loss and $m > 0$ is a margin parameter. Intuitively, we want to
force positive pairs (with the same label) to be close, and negative pairs (with different labels) to be
further apart than some minimal safety margin. We minimize this loss over all pairs of data. Naively
this takes $O(N^2)$ time; see Section 16.2.5 for some speedups.

Note that we use the same feature extractor $f(\cdot; \boldsymbol{\theta})$ for both inputs, $\mathbf{x}_i$ and $\mathbf{x}_j$, when computing
the distance, as illustrated in Figure 16.5a. The resulting network is therefore called a Siamese 
network (named after Siamese twins). 
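The pairwise loss can be sketched in a few lines of plain Python. This is a minimal illustration that operates on embeddings which are assumed to be already computed by the shared feature extractor; the function names, the default margin and the example points are our own assumptions, and the squared hinge on negative pairs follows the standard form of the contrastive loss.

```python
import math

def distance(a, b):
    """Euclidean distance between two embedding vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def contrastive_loss(e_i, e_j, same_label, margin=1.0):
    """Pairwise contrastive loss for one pair of embeddings.

    Positive pairs are penalized by their squared distance; negative pairs
    are penalized only if they are closer than the margin.
    """
    d = distance(e_i, e_j)
    if same_label:
        return d ** 2
    return max(0.0, margin - d) ** 2

anchor = [0.0, 0.0]
near, far = [0.3, 0.4], [3.0, 4.0]   # at distance 0.5 and 5.0 from the anchor

print(contrastive_loss(anchor, near, same_label=True))    # pulls the pair together
print(contrastive_loss(anchor, near, same_label=False))   # pushes the pair apart
print(contrastive_loss(anchor, far, same_label=False))    # already beyond the margin
```

In a Siamese setup, both arguments would come from the same network applied to the two inputs, so gradients from each pair update a single set of parameters.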


16.2.4.2 Triplet loss 


One disadvantage of pairwise losses is that the optimization of the positive pairs is independent of 
the negative pairs, which can make their magnitudes incomparable. A solution to this is to use the 
triplet loss [SKP15]. This is defined as follows. For each example $i$ (known as an anchor), we
find a similar (positive) example $\mathbf{x}_i^+$ and a dissimilar (negative) example $\mathbf{x}_i^-$. We then minimize the
following loss, averaged over all triples:

$$\mathcal{L}(\boldsymbol{\theta}; \mathbf{x}_i, \mathbf{x}_i^+, \mathbf{x}_i^-) = [d_{\boldsymbol{\theta}}(\mathbf{x}_i, \mathbf{x}_i^+)^2 - d_{\boldsymbol{\theta}}(\mathbf{x}_i, \mathbf{x}_i^-)^2 + m]_+ \tag{16.11}$$

Intuitively this says we want the distance from the anchor to the positive to be less (by some safety
margin $m$) than the distance from the anchor to the negative. We can compute the triplet loss using
a triplet network as shown in Figure 16.5b.

Naively minimizing triplet loss takes $O(N^3)$ time. In practice we compute the loss on a minibatch
(chosen so that there is at least one similar and one dissimilar example for the anchor point, often 
taken to be the first entry in the minibatch). Nevertheless the method can be slow. We discuss some 
speedups in Section 16.2.5. 
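The triplet loss of Equation (16.11) can be sketched as follows. The code works on precomputed embeddings; the function names, the margin value and the example triplets are illustrative assumptions rather than part of any particular library.

```python
def sq_distance(a, b):
    """Squared Euclidean distance between two embedding vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def triplet_loss(anchor, positive, negative, margin=0.2):
    """Hinge on d(a, p)^2 - d(a, n)^2 + margin for a single triplet."""
    return max(0.0, sq_distance(anchor, positive) - sq_distance(anchor, negative) + margin)

def batch_triplet_loss(triplets, margin=0.2):
    """Average the triplet loss over a list of (anchor, positive, negative) tuples."""
    losses = [triplet_loss(a, p, n, margin) for a, p, n in triplets]
    return sum(losses) / len(losses)

a, p = [0.0, 0.0], [0.1, 0.0]
easy_negative = [1.0, 0.0]   # far from the anchor, so it contributes zero loss
close_negative = [0.3, 0.0]  # inside the margin, so it contributes a positive loss

print(triplet_loss(a, p, easy_negative))
print(triplet_loss(a, p, close_negative))
print(batch_triplet_loss([(a, p, easy_negative), (a, p, close_negative)]))
```

Note that the first triplet incurs no loss at all, which is the observation that motivates the mining techniques discussed in Section 16.2.5.1.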


16.2.4.3 N-pairs loss 


One problem with the triplet loss is that each anchor is only compared to one negative example at a 
time. This might not provide a strong enough learning signal. One solution to this is to create a 
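multi-class classification problem in which we create a set of $N - 1$ negatives and 1 positive for every
anchor. This is called the N-pairs loss [Soh16]. More precisely, we define the following loss for each
set:

$$\mathcal{L}(\boldsymbol{\theta}; \mathbf{x}, \mathbf{x}^+, \{\mathbf{x}_k^-\}_{k=1}^{N-1}) = \log\left(1 + \sum_{k=1}^{N-1} \exp\left(\hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x})^{\mathsf{T}} \hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x}_k^-) - \hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x})^{\mathsf{T}} \hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x}^+)\right)\right) \tag{16.12}$$

$$= -\log \frac{\exp(\hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x})^{\mathsf{T}} \hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x}^+))}{\exp(\hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x})^{\mathsf{T}} \hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x}^+)) + \sum_{k=1}^{N-1} \exp(\hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x})^{\mathsf{T}} \hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x}_k^-))} \tag{16.13}$$

Note that the N-pairs loss is the same as the InfoNCE loss used in the CPC paper [OLV18]. In
[Che+20a], they propose a version where they scale the similarities by a temperature term; they call
this the NT-Xent (normalized temperature-scaled cross-entropy) loss. We can view the temperature
parameter as scaling the radius of the hypersphere on which the data lives.

When $N = 2$, the loss reduces to the logistic loss

$$\mathcal{L}(\boldsymbol{\theta}; \mathbf{x}, \mathbf{x}^+, \mathbf{x}^-) = \log\left(1 + \exp\left(\hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x})^{\mathsf{T}} \hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x}^-) - \hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x})^{\mathsf{T}} \hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x}^+)\right)\right) \tag{16.14}$$

Compare this to the margin loss used by triplet learning (when $m = 1$):

$$\mathcal{L}(\boldsymbol{\theta}; \mathbf{x}, \mathbf{x}^+, \mathbf{x}^-) = \max\left(0,\; \hat{\mathbf{e}}(\mathbf{x})^{\mathsf{T}} \hat{\mathbf{e}}(\mathbf{x}^-) - \hat{\mathbf{e}}(\mathbf{x})^{\mathsf{T}} \hat{\mathbf{e}}(\mathbf{x}^+) + 1\right) \tag{16.15}$$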


See Figure 4.2 for a comparison of these two functions. 
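The two equivalent forms of the N-pairs loss, and its reduction to the logistic loss when there is a single negative, can be checked numerically. The sketch below assumes unit-norm embeddings given as plain lists; all function names and example vectors are our own.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def n_pairs_loss(anchor, positive, negatives):
    """log(1 + sum_k exp(a.n_k - a.p)), the first form of the N-pairs loss."""
    s_pos = dot(anchor, positive)
    return math.log(1.0 + sum(math.exp(dot(anchor, n) - s_pos) for n in negatives))

def softmax_form(anchor, positive, negatives):
    """Negative log softmax probability of the positive among all candidates."""
    e_pos = math.exp(dot(anchor, positive))
    e_neg = sum(math.exp(dot(anchor, n)) for n in negatives)
    return -math.log(e_pos / (e_pos + e_neg))

def logistic_loss(anchor, positive, negative):
    """Special case with one negative (N = 2)."""
    return math.log(1.0 + math.exp(dot(anchor, negative) - dot(anchor, positive)))

# Unit-norm example embeddings (illustrative values only).
a = [1.0, 0.0]
p = [0.8, 0.6]
negs = [[0.0, 1.0], [-1.0, 0.0], [0.6, -0.8]]

print(n_pairs_loss(a, p, negs), softmax_form(a, p, negs))           # same value
print(n_pairs_loss(a, p, negs[:1]), logistic_loss(a, p, negs[0]))   # same value
```

A temperature-scaled variant would divide each dot product by a temperature before exponentiating; that parameter is omitted here to stay close to the equations above.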


16.2.5 Speeding up ranking loss optimization 


The main disadvantage of ranking loss is the $O(N^2)$ or $O(N^3)$ cost of computing the loss function,
due to the need to compare all pairs or triples of examples. In this section, we discuss various speedup 
tricks. 


16.2.5.1 Mining techniques 


A key insight is that we don’t need to consider all negative examples for each anchor, since most will 
be uninformative (i.e., will incur zero loss). Instead we can focus attention on negative examples
which are closer to the anchor than its nearest positive example. These are called hard negatives,
and are particularly useful for speeding up triplet loss.

Figure 16.6: Speeding up triplet loss minimization. (a) Illustration of hard vs easy negatives. Here $a$ is the
anchor point, $p$ is a positive point, and $n_i$ are negative points. Adapted from Figure 4 of [KB19]. (b) Standard
triplet loss would take $8 \times 3 \times 4 = 96$ calculations, whereas using a proxy loss (with one proxy per class) takes
$8 \times 2 = 16$ calculations. From Figure 1 of [Do+19]. Used with kind permission of Gustavo Carneiro.


More precisely, if $a$ is an anchor and $p$ is its nearest positive example, we say that $n$ is a hard
negative (for $a$) if $d(\mathbf{x}_a, \mathbf{x}_n) < d(\mathbf{x}_a, \mathbf{x}_p)$ and $y_n \neq y_a$. Sometimes an anchor may not have any hard
negatives. We can therefore increase the pool of candidates by considering semi-hard negatives,
for which

$$d(\mathbf{x}_a, \mathbf{x}_p) < d(\mathbf{x}_a, \mathbf{x}_n) < d(\mathbf{x}_a, \mathbf{x}_p) + m \tag{16.16}$$


where m > 0 is a margin parameter. See Figure 16.6a for an illustration. This is the technique used 
by Google’s FaceNet model [SKP15], which learns an embedding function for faces, so it can cluster 
similar looking faces together, to which the user can attach a name. 

In practice, the hard negatives are usually chosen from within the minibatch. This therefore 
requires large batch sizes to ensure sufficient diversity. Alternatively, we can have a separate process 
that continually updates the set of candidate hard negatives, as the distance measure evolves during 
training. 
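The distinction between hard, semi-hard and easy negatives can be sketched as a simple partition of the candidate negatives for one anchor. This is a minimal illustration on precomputed embeddings; the function name `mine_negatives`, the margin and the example points are assumptions made for this sketch, and a negative that lies exactly on a boundary is assigned to the larger-distance group, which is an arbitrary choice.

```python
import math

def distance(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def mine_negatives(anchor, nearest_positive, negatives, margin):
    """Return index lists (hard, semi_hard, easy) for the candidate negatives.

    hard:      d(a, n) <  d(a, p)
    semi_hard: d(a, p) <= d(a, n) < d(a, p) + margin
    easy:      d(a, n) >= d(a, p) + margin  (zero triplet loss)
    """
    d_ap = distance(anchor, nearest_positive)
    hard, semi_hard, easy = [], [], []
    for idx, n in enumerate(negatives):
        d_an = distance(anchor, n)
        if d_an < d_ap:
            hard.append(idx)
        elif d_an < d_ap + margin:
            semi_hard.append(idx)
        else:
            easy.append(idx)
    return hard, semi_hard, easy

anchor, positive = [0.0, 0.0], [1.0, 0.0]
candidates = [[0.5, 0.0], [1.2, 0.0], [3.0, 0.0]]
hard, semi_hard, easy = mine_negatives(anchor, positive, candidates, margin=0.5)
print(hard, semi_hard, easy)
```

In a minibatch setting the same partition would be computed for every anchor in the batch, and triplets would then be formed only from the hard or semi-hard indices.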


16.2.5.2 Proxy methods 


Triplet loss minimization is expensive even with hard negative mining (Section 16.2.5.1). Ideally we 
can find a method that is O(N) time, just like classification loss. 

One such method, proposed in [MA+17], measures the distance between each anchor and a set 
of P proxies that represent each class, rather than directly measuring distance between examples. 
These proxies need to be updated online as the distance metric evolves during learning. The overall 
procedure takes $O(NP^2)$ time, where $P \sim C$.

More recently, [Qia+19] proposed to represent each class with multiple prototypes, while still 
achieving linear time complexity, using a soft triple loss. 


16.2.5.3 Optimizing an upper bound 


[Do+19] proposed a simple and fast method for optimizing the triplet loss. The key idea is to define 
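one fixed proxy or centroid per class, and then to use distance to the proxy as an upper bound on
the triplet loss.

More precisely, consider a simplified form of the triplet loss, without the margin term:

$$\ell_t(\mathbf{x}_i, \mathbf{x}_j, \mathbf{x}_k) = \|\hat{\mathbf{e}}_i - \hat{\mathbf{e}}_j\| - \|\hat{\mathbf{e}}_i - \hat{\mathbf{e}}_k\| \tag{16.17}$$

where $\hat{\mathbf{e}}_i = \hat{\mathbf{e}}_{\boldsymbol{\theta}}(\mathbf{x}_i)$, etc. Using the triangle inequality we have

$$\|\hat{\mathbf{e}}_i - \hat{\mathbf{e}}_j\| \le \|\hat{\mathbf{e}}_i - \mathbf{c}_{y_i}\| + \|\hat{\mathbf{e}}_j - \mathbf{c}_{y_i}\| \tag{16.18}$$

$$\|\hat{\mathbf{e}}_i - \hat{\mathbf{e}}_k\| \ge \|\hat{\mathbf{e}}_i - \mathbf{c}_{y_k}\| - \|\hat{\mathbf{e}}_k - \mathbf{c}_{y_k}\| \tag{16.19}$$

Hence

$$\ell_t(\mathbf{x}_i, \mathbf{x}_j, \mathbf{x}_k) \le \ell_u(\mathbf{x}_i, \mathbf{x}_j, \mathbf{x}_k) \triangleq \|\hat{\mathbf{e}}_i - \mathbf{c}_{y_i}\| - \|\hat{\mathbf{e}}_i - \mathbf{c}_{y_k}\| + \|\hat{\mathbf{e}}_j - \mathbf{c}_{y_i}\| + \|\hat{\mathbf{e}}_k - \mathbf{c}_{y_k}\| \tag{16.20}$$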


We can use this to derive a tractable upper bound on the triplet loss as follows: 


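$$\mathcal{L}_t(\mathcal{D}, \mathcal{S}) = \sum_{(i,j) \in \mathcal{S},\, (i,k) \notin \mathcal{S},\, i,j,k \in \{1,\ldots,N\}} \ell_t(\mathbf{x}_i, \mathbf{x}_j, \mathbf{x}_k) \le \sum_{(i,j) \in \mathcal{S},\, (i,k) \notin \mathcal{S},\, i,j,k \in \{1,\ldots,N\}} \ell_u(\mathbf{x}_i, \mathbf{x}_j, \mathbf{x}_k) \tag{16.21}$$

$$= C' \sum_{i=1}^{N} \left( \|\hat{\mathbf{e}}_i - \mathbf{c}_{y_i}\| - \frac{1}{3(C-1)} \sum_{m=1,\, m \neq y_i}^{C} \|\hat{\mathbf{e}}_i - \mathbf{c}_m\| \right) \triangleq \mathcal{L}_u(\mathcal{D}, \mathcal{S}) \tag{16.22}$$

where $C' = 3(C-1)\left(\frac{N}{C} - 1\right)\frac{N}{C}$ is a constant, and we assume that $\sum_{i=1}^{N} \mathbb{I}(y_i = c) = N/C$ for each $c$.
It is clear that $\mathcal{L}_u$ can be computed in $O(NC)$ time. See Figure 16.6b for an illustration.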

In [Do+19], they show that $0 \le \mathcal{L}_u - \mathcal{L}_t \le \frac{N^3}{C^2} K$, where $K$ is some constant that depends on the
spread of the centroids. To ensure the bound is tight, the centroids should be as far from each other
as possible, and the distances between them should be as similar as possible. An easy way to ensure
this is to define the $\mathbf{c}_m$ vectors to be one-hot vectors, one per class. These vectors already have unit
norm, and are orthogonal to each other. The distance between each pair of centroids is $\sqrt{2}$, which
ensures the upper bound is fairly tight.
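The per-triplet bound can be checked numerically with one-hot centroids. The following sketch evaluates the simplified triplet loss and its proxy-based upper bound for a single made-up triplet; the function names and embedding values are illustrative assumptions, and the embedding dimension equals the number of classes, as this construction requires.

```python
import math

def normalize(v):
    n = math.sqrt(sum(a * a for a in v))
    return [a / n for a in v]

def dist(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def triplet_no_margin(e_i, e_j, e_k):
    """Simplified triplet loss: d(i, j) - d(i, k), with no margin term."""
    return dist(e_i, e_j) - dist(e_i, e_k)

def proxy_upper_bound(e_i, e_j, e_k, c_i, c_k):
    """Bound built from distances to centroids; c_i is the centroid shared by i and j."""
    return dist(e_i, c_i) - dist(e_i, c_k) + dist(e_j, c_i) + dist(e_k, c_k)

# One-hot centroids for C = 3 classes.
centroids = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

e_i = normalize([0.9, 0.1, 0.1])  # anchor, class 0
e_j = normalize([0.8, 0.2, 0.0])  # positive, class 0
e_k = normalize([0.1, 0.9, 0.2])  # negative, class 1

l_t = triplet_no_margin(e_i, e_j, e_k)
l_u = proxy_upper_bound(e_i, e_j, e_k, centroids[0], centroids[1])
print(l_t <= l_u)                         # the triangle inequality guarantees this
print(dist(centroids[0], centroids[1]))   # every pair of one-hot centroids is sqrt(2) apart
```

The bound only involves distances from each example to the fixed centroids, which is why summing it over the dataset avoids enumerating triplets.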

The downside of this approach is that it assumes the embedding layer is $L = C$ dimensional. There
are two solutions to this. First, after training, we can add a linear projection layer to map from $C$
to $L \neq C$, or we can take the second-to-last layer of the embedding network. The second approach
is to sample a large number of points on the $L$-dimensional unit hyper-sphere (which we can do
by sampling from the standard normal, and then normalizing [Mar72]), and then run K-means
clustering (Section 21.3) with $K = C$. In the experiments reported in [Do+19], these two approaches
give similar results.

Interestingly, in [Rot+20], they show that increasing $\pi_{\text{intra}}/\pi_{\text{inter}}$ results in improved downstream
performance on various retrieval tasks, where

$$\pi_{\text{intra}} = \frac{1}{Z_{\text{intra}}} \sum_{c=1}^{C} \sum_{i \neq j:\, y_i = y_j = c} d(\mathbf{x}_i, \mathbf{x}_j) \tag{16.23}$$

is the average intra-class distance, and

$$\pi_{\text{inter}} = \frac{1}{Z_{\text{inter}}} \sum_{c=1}^{C} \sum_{c'=1}^{C} d(\boldsymbol{\mu}_c, \boldsymbol{\mu}_{c'}) \tag{16.24}$$

is the average inter-class distance, where $\boldsymbol{\mu}_c = \frac{1}{Z_c} \sum_{i:\, y_i = c} \hat{\mathbf{e}}_i$ is the mean embedding for examples
from class $c$. This suggests that we should not only keep the centroids far apart, but we should also
prevent examples from getting too close to their centroids (which would shrink $\pi_{\text{intra}}$, the numerator
of this ratio); this latter term is not captured in the method of [Do+19].

Figure 16.7: Adding spherical embedding constraint to a deep metric learning method. Used with kind
permission of Dingyi Zhang.


16.2.6 Other training tricks for DML 


Besides the speedup tricks in Section 16.2.5, there are a lot of other details that are important to get 
right in order to ensure good DML performance. Many of these details are discussed in [MBL20; 
Rot+20]. Here we just briefly mention a few. 

One important issue is how the minibatches are created. In classification problems (at least with 
balanced classes), selecting examples at random from the training set is usually sufficient. However, 
for DML, we need to ensure that each example has some other examples in the minibatch that are 
similar to it, as well as some others that are dissimilar to it. One approach is to use hard mining 
techniques (Section 16.2.5.1). Another idea is to use coreset methods applied to previously learned 
embeddings to select a diverse minibatch at each step [Sin+20]. However, [Rot+20] show that the 
following simple strategy also works well for creating each batch: pick $B/N_c$ classes, and then pick $N_c$
examples randomly from each class, where $B$ is the batch size, and $N_c = 2$ is a tuning parameter.
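The class-balanced batching strategy can be sketched with the standard library alone. The function name `sample_batch`, its arguments and the toy label list are hypothetical; the sketch assumes that the batch size is a multiple of the number of examples per class and that enough classes have at least that many examples (otherwise `random.sample` raises a `ValueError`).

```python
import random
from collections import defaultdict

def sample_batch(labels, batch_size, n_per_class=2, rng=None):
    """Pick batch_size // n_per_class classes, then n_per_class examples from each.

    Returns a list of dataset indices.
    """
    rng = rng or random.Random(0)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    # Only classes with enough members can supply n_per_class distinct examples.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= n_per_class]
    chosen_classes = rng.sample(eligible, batch_size // n_per_class)
    batch = []
    for c in chosen_classes:
        batch.extend(rng.sample(by_class[c], n_per_class))
    return batch

labels = [i % 5 for i in range(50)]   # toy dataset: 5 classes, 10 examples each
batch = sample_batch(labels, batch_size=8, n_per_class=2)
print(batch)
print([labels[i] for i in batch])     # each sampled class appears exactly twice
```

With this construction every example in the batch has at least one similar partner and several dissimilar ones, so every anchor can form at least one valid triplet.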

Another important issue is avoiding overfitting. Since most datasets used in the DML literature 
are small, it is standard to use an image classifier, such as GoogLeNet (Section 14.3.3) or ResNet 
(Section 14.3.4), which has been pre-trained on ImageNet, and then to fine-tune the model using the 
DML loss. (See Section 19.2 for more details on this kind of transfer learning.) In addition, it is 
standard to use data augmentation (see Section 19.1). (Indeed, with some self-supervised learning 
methods, data augmentation is the only way to create similar pairs.)

In [ZLZ20], they propose to add a spherical embedding constraint (SEC), which is an additional 
batchwise regularization term, which encourages all the examples to have the same norm. That is, 
the regularizer is just the empirical variance of the norms of the (unnormalized) embeddings in that 
batch. See Figure 16.7 for an illustration. This regularizer can be added to any of the existing DML 
losses to modestly improve training speed and stability, as well as final performance, analogously to 
how batchnorm (Section 14.2.4.1) is used. 
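Following the description above, a regularizer of this kind can be sketched as the variance of the embedding norms within a batch. This is only an illustration of that description: the function name is our own, the population (divide-by-$B$) variance is an assumption, and the exact weighting used in [ZLZ20] may differ.

```python
import math

def norm_variance_regularizer(embeddings):
    """Empirical variance of the Euclidean norms of the unnormalized embeddings."""
    norms = [math.sqrt(sum(v * v for v in e)) for e in embeddings]
    mean_norm = sum(norms) / len(norms)
    return sum((n - mean_norm) ** 2 for n in norms) / len(norms)

same_norm_batch = [[3.0, 4.0], [0.0, 5.0], [5.0, 0.0]]   # all norms equal 5
mixed_norm_batch = [[3.0, 4.0], [0.0, 1.0]]              # norms 5 and 1

print(norm_variance_regularizer(same_norm_batch))    # no penalty when norms agree
print(norm_variance_regularizer(mixed_norm_batch))   # penalty grows with the spread
```

In training, this term would be multiplied by a small weight and added to whichever DML loss is being minimized.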


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




16.3 Kernel density estimation (KDE) 


In this section, we consider a form of non-parametric density estimation known as kernel density 
estimation or KDE. This is a form of generative model, since it defines a probability distribution 
p(x) that can be evaluated pointwise, and which can be sampled from to generate new data. 


16.3.1 Density kernels 


Before explaining KDE, we must define what we mean by a “kernel”. This term has several different 
meanings in machine learning and statistics.$^1$ In this section, we use a specific kind of kernel which
we refer to as a density kernel. This is a function $\mathcal{K} : \mathbb{R} \to \mathbb{R}_+$ such that $\int \mathcal{K}(x)\,dx = 1$ and
$\mathcal{K}(-x) = \mathcal{K}(x)$. This latter symmetry property implies that $\int x\,\mathcal{K}(x)\,dx = 0$, and hence

$$\int x\,\mathcal{K}(x - x_n)\,dx = x_n \tag{16.25}$$


A simple example of such a kernel is the boxcar kernel, which is the uniform distribution within 
the unit interval around the origin: 


$$\mathcal{K}(x) \triangleq 0.5\,\mathbb{I}(|x| \le 1) \tag{16.26}$$

Another example is the Gaussian kernel:

$$\mathcal{K}(x) = \frac{1}{(2\pi)^{1/2}}\, e^{-x^2/2} \tag{16.27}$$


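We can control the width of the kernel by introducing a bandwidth parameter $h$:

$$\mathcal{K}_h(x) \triangleq \frac{1}{h}\, \mathcal{K}\!\left(\frac{x}{h}\right) \tag{16.28}$$

We can generalize to vector valued inputs by defining a radial basis function or RBF kernel:

$$\mathcal{K}_h(\mathbf{x}) \propto \mathcal{K}_h(\|\mathbf{x}\|) \tag{16.29}$$

In the case of the Gaussian kernel, this becomes

$$\mathcal{K}_h(\mathbf{x}) = \frac{1}{h^D (2\pi)^{D/2}} \prod_{d=1}^{D} \exp\left(-\frac{1}{2h^2} x_d^2\right) \tag{16.30}$$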


Although Gaussian kernels are popular, they have unbounded support. Some alternative kernels,
which have compact support (which can be computationally faster), are listed in Table 16.1. See
Figure 16.8 for a plot of these kernel functions.

1. For a good blog post on this, see https://francisbach.com/cursed-kernels/.

Figure 16.8: A comparison of some popular normalized kernels. Generated by smoothingKernelPlot.ipynb.

Name                 Definition                                                              Compact  Smooth  Boundaries
Gaussian             $\mathcal{K}(x) = (2\pi)^{-1/2} e^{-x^2/2}$                             0        1       1
Boxcar               $\mathcal{K}(x) = \frac{1}{2}\,\mathbb{I}(|x| \le 1)$                   1        0       0
Epanechnikov kernel  $\mathcal{K}(x) = \frac{3}{4}(1 - x^2)\,\mathbb{I}(|x| \le 1)$          1        1       0
Tri-cube kernel      $\mathcal{K}(x) = \frac{70}{81}(1 - |x|^3)^3\,\mathbb{I}(|x| \le 1)$    1        1       1

Table 16.1: List of some popular normalized kernels in 1d. Compact=1 means the function is non-zero
for a finite range of inputs. Smooth=1 means the function is differentiable over the range of its support.
Boundaries=1 means the function is also differentiable at the boundaries of its support.
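The defining properties of a density kernel (it integrates to one, and a shifted copy has its mean at the shift) can be verified numerically for the kernels in Table 16.1. The sketch below uses a simple midpoint rule; the function names, the integration range and the grid size are our own choices for illustration.

```python
import math

def gaussian(x):
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)

def boxcar(x):
    return 0.5 if abs(x) <= 1.0 else 0.0

def epanechnikov(x):
    return 0.75 * (1.0 - x * x) if abs(x) <= 1.0 else 0.0

def tricube(x):
    return (70.0 / 81.0) * (1.0 - abs(x) ** 3) ** 3 if abs(x) <= 1.0 else 0.0

def with_bandwidth(kernel, h):
    """Return the scaled kernel K_h(x) = K(x / h) / h."""
    return lambda x: kernel(x / h) / h

def integrate(f, lo=-12.0, hi=12.0, n=48_000):
    """Midpoint-rule approximation of the integral of f over [lo, hi]."""
    dx = (hi - lo) / n
    return sum(f(lo + (i + 0.5) * dx) for i in range(n)) * dx

for name, k in [("gaussian", gaussian), ("boxcar", boxcar),
                ("epanechnikov", epanechnikov), ("tricube", tricube)]:
    area = integrate(with_bandwidth(k, 2.0))        # should be close to 1
    mean = integrate(lambda x: x * k(x - 1.5))      # should be close to the shift, 1.5
    print(name, round(area, 3), round(mean, 3))
```

Rescaling by the bandwidth changes the width of the kernel but not its total mass, which is why the factor $1/h$ appears in Equation (16.28).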


16.3.2 Parzen window density estimator 


To explain how to use kernels to define a nonparametric density estimate, recall the form of the 
Gaussian mixture model from Section 3.5.1. If we assume a fixed spherical Gaussian covariance and 
uniform mixture weights, we get 


$$p(\mathbf{x}|\boldsymbol{\theta}) = \frac{1}{K} \sum_{k=1}^{K} \mathcal{N}(\mathbf{x}|\boldsymbol{\mu}_k, \sigma^2 \mathbf{I}) \tag{16.31}$$


One problem with this model is that it requires specifying the number $K$ of clusters, as well as their
locations $\boldsymbol{\mu}_k$. An alternative to estimating these parameters is to allocate one cluster center per data
point. In this case, the model becomes

$$p(\mathbf{x}|\mathcal{D}) = \frac{1}{N} \sum_{n=1}^{N} \mathcal{N}(\mathbf{x}|\mathbf{x}_n, \sigma^2 \mathbf{I}) \tag{16.32}$$

We can generalize Equation (16.32) by writing

$$p(\mathbf{x}|\mathcal{D}) = \frac{1}{N} \sum_{n=1}^{N} \mathcal{K}_h(\mathbf{x} - \mathbf{x}_n) \tag{16.33}$$

where $\mathcal{K}_h$ is a density kernel. This is called a Parzen window density estimator, or kernel
density estimator (KDE).

Figure 16.9: A nonparametric (Parzen) density estimator in 1d estimated from 6 data points, denoted by x.
Top row: uniform kernel. Bottom row: Gaussian kernel. Left column: bandwidth parameter $h = 1$. Right
column: bandwidth parameter $h = 2$. Adapted from http://en.wikipedia.org/wiki/Kernel_density_estimation.
Generated by parzen_window_demo2.ipynb.

The advantage over a parametric model is that no model fitting is required (except for choosing 
h, discussed in Section 16.3.3), and there is no need to pick the number of cluster centers. The 
disadvantage is that the model takes a lot of memory (you need to store all the data) and a lot of 
time to evaluate. 

Figure 16.9 illustrates KDE in 1d for two kinds of kernel. On the top, we use a boxcar kernel; the 
resulting model just counts how many data points land within an interval of size h around each x, 
to get a piecewise constant density. On the bottom, we use a Gaussian kernel, which results in a 
smoother density. 
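A Parzen window estimator in 1d takes only a few lines. The sketch below is a minimal illustration rather than the notebook that generated Figure 16.9; the function names are our own and the six data points are made up.

```python
import math

def gaussian(u):
    return math.exp(-0.5 * u * u) / math.sqrt(2.0 * math.pi)

def boxcar(u):
    return 0.5 if abs(u) <= 1.0 else 0.0

def parzen_density(x, data, h, kernel=gaussian):
    """Evaluate p(x|D) = (1/N) * sum_n K_h(x - x_n), where K_h(u) = K(u / h) / h."""
    return sum(kernel((x - xn) / h) / h for xn in data) / len(data)

data = [-2.1, -1.3, -0.4, 1.9, 5.1, 6.2]   # six made-up 1d observations

for h in (1.0, 2.0):
    print(h,
          parzen_density(0.0, data, h, kernel=boxcar),
          parzen_density(0.0, data, h, kernel=gaussian))
```

Every query touches all $N$ stored points, which makes concrete the memory and evaluation cost mentioned above.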


16.3.3 How to choose the bandwidth parameter 


We see from Figure 16.9 that the bandwidth parameter h has a large effect on the learned distribution. 
We can view this as controlling the complexity of the model. 

In the case of 1d data, where the “true” data generating distribution is assumed to be a Gaussian, 
one can show [BA97a] that the optimal bandwidth for a Gaussian kernel (from the point of view of
minimizing frequentist risk) is given by $h = \sigma \left(\frac{4}{3N}\right)^{1/5}$. We can compute a robust approximation to the
standard deviation by first computing the median absolute deviation, $\text{MAD} = \text{median}(|x - \text{median}(x)|)$,
and then using $\hat{\sigma} = 1.4826\,\text{MAD}$. If we have $D$ dimensions, we can estimate $h_d$ separately for each
dimension, and then set $h = \left(\prod_{d=1}^{D} h_d\right)^{1/D}$.
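This bandwidth rule is easy to compute with the standard library. The sketch below assumes the rule $h = \hat{\sigma}\,(4/(3N))^{1/5}$ with the MAD-based estimate of the standard deviation; the function names and the toy data (which include one outlier) are our own.

```python
import statistics

def robust_sigma(xs):
    """Robust estimate of the standard deviation: 1.4826 times the median absolute deviation."""
    med = statistics.median(xs)
    mad = statistics.median(abs(x - med) for x in xs)
    return 1.4826 * mad

def gaussian_bandwidth(xs):
    """Rule-of-thumb bandwidth sigma * (4 / (3 N))^(1/5) for a Gaussian kernel in 1d."""
    n = len(xs)
    return robust_sigma(xs) * (4.0 / (3.0 * n)) ** 0.2

data = [1.0, 2.0, 3.0, 4.0, 100.0]   # the outlier barely affects the MAD
print(robust_sigma(data))
print(gaussian_bandwidth(data))
```

Using the ordinary sample standard deviation on the same data would give a much larger bandwidth, since the single outlier dominates it.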


16.3.4 From KDE to KNN classification 


In Section 16.1, we discussed the K nearest neighbor classifier as a heuristic approach to classification. 
Interestingly, we can derive it as a generative classifier in which the class conditional densities 
$p(\mathbf{x}|y = c)$ are modeled using KDE. Rather than using a fixed bandwidth and counting how many
data points fall within the hyper-cube centered on a datapoint, we will allow the bandwidth or
volume to be different for each data point. Specifically, we will “grow” a volume around $\mathbf{x}$ until we
encounter $K$ data points, regardless of their class label. This is called a balloon kernel density
estimator [TS92]. Let the resulting volume have size $V(\mathbf{x})$ (this was previously $h^D$), and let there
be $N_c(\mathbf{x})$ examples from class $c$ in this volume. Then we can estimate the class conditional density
as follows:

$$p(\mathbf{x}|y = c, \mathcal{D}) = \frac{N_c(\mathbf{x})}{N_c V(\mathbf{x})} \tag{16.34}$$

where $N_c$ is the total number of examples in class $c$ in the whole data set. If we take the class prior
to be $p(y = c) = N_c/N$, then the class posterior is given by

$$p(y = c|\mathbf{x}, \mathcal{D}) = \frac{\frac{N_c(\mathbf{x})}{N_c V(\mathbf{x})} \frac{N_c}{N}}{\sum_{c'} \frac{N_{c'}(\mathbf{x})}{N_{c'} V(\mathbf{x})} \frac{N_{c'}}{N}} = \frac{N_c(\mathbf{x})}{\sum_{c'} N_{c'}(\mathbf{x})} = \frac{N_c(\mathbf{x})}{K} = \frac{1}{K} \sum_{n \in N_K(\mathbf{x}, \mathcal{D})} \mathbb{I}(y_n = c) \tag{16.35}$$

where we used the fact that $\sum_c N_c(\mathbf{x}) = K$, since we choose a total of $K$ points (regardless of class)
around every point. This matches Equation (16.1).
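The resulting classifier simply reports the class proportions among the $K$ nearest neighbors. A minimal sketch, using Euclidean distance and a toy two-class dataset of our own invention, is shown below; the function name `knn_posterior` is hypothetical.

```python
import math
from collections import Counter

def knn_posterior(x, data, labels, k):
    """Return {class: N_c(x) / K}, the fraction of the K nearest neighbors in each class."""
    order = sorted(range(len(data)), key=lambda n: math.dist(x, data[n]))
    counts = Counter(labels[n] for n in order[:k])
    return {c: counts[c] / k for c in sorted(set(labels))}

data = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]]
labels = [0, 0, 0, 1, 1, 1]

print(knn_posterior([1.0, 1.0], data, labels, k=3))   # all three neighbors are class 0
print(knn_posterior([4.0, 4.0], data, labels, k=5))   # three class-1 and two class-0 neighbors
```

The volume $V(\mathbf{x})$ never needs to be computed explicitly, because it cancels in the ratio of Equation (16.35).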


16.3.5 Kernel regression 
Just as KDE can be used for generative classifiers (see Section 16.1), it can also be used for generative 
models for regression, as we discuss below. 


16.3.5.1 Nadaraya-Watson estimator for the mean 


In regression, our goal is to compute the conditional expectation

$$\mathbb{E}[y|\mathbf{x}, \mathcal{D}] = \int y\, p(y|\mathbf{x}, \mathcal{D})\,dy = \frac{\int y\, p(\mathbf{x}, y|\mathcal{D})\,dy}{\int p(\mathbf{x}, y|\mathcal{D})\,dy} \tag{16.36}$$

If we use an MVN for $p(y, \mathbf{x}|\mathcal{D})$, we derive a result which is equivalent to linear regression, as we
showed in Section 11.2.3.5. However, the assumption that $p(y, \mathbf{x}|\mathcal{D})$ is Gaussian is rather limiting.
We can use KDE to more accurately approximate the joint density $p(\mathbf{x}, y|\mathcal{D})$ as follows:

$$p(y, \mathbf{x}|\mathcal{D}) \approx \frac{1}{N} \sum_{n=1}^{N} \mathcal{K}_h(\mathbf{x} - \mathbf{x}_n)\, \mathcal{K}_h(y - y_n) \tag{16.37}$$

Figure 16.10: An example of kernel regression in 1d using a Gaussian kernel. Generated by kernelRegressionDemo.ipynb.

Hence

$$\mathbb{E}[y|\mathbf{x}, \mathcal{D}] = \frac{\frac{1}{N} \sum_{n=1}^{N} \mathcal{K}_h(\mathbf{x} - \mathbf{x}_n) \int y\, \mathcal{K}_h(y - y_n)\,dy}{\frac{1}{N} \sum_{n'=1}^{N} \mathcal{K}_h(\mathbf{x} - \mathbf{x}_{n'}) \int \mathcal{K}_h(y - y_{n'})\,dy} \tag{16.38}$$

We can simplify the numerator using the fact that $\int y\, \mathcal{K}_h(y - y_n)\,dy = y_n$ (from Equation (16.25)). We
can simplify the denominator using the fact that density kernels integrate to one, i.e., $\int \mathcal{K}_h(y - y_n)\,dy = 1$.
Thus

$$\mathbb{E}[y|\mathbf{x}, \mathcal{D}] = \frac{\sum_{n=1}^{N} \mathcal{K}_h(\mathbf{x} - \mathbf{x}_n)\, y_n}{\sum_{n'=1}^{N} \mathcal{K}_h(\mathbf{x} - \mathbf{x}_{n'})} = \sum_{n=1}^{N} y_n w_n(\mathbf{x}) \tag{16.39}$$

$$w_n(\mathbf{x}) \triangleq \frac{\mathcal{K}_h(\mathbf{x} - \mathbf{x}_n)}{\sum_{n'=1}^{N} \mathcal{K}_h(\mathbf{x} - \mathbf{x}_{n'})} \tag{16.40}$$

We see that the prediction is just a weighted sum of the outputs at the training points, where the
weights depend on how similar $\mathbf{x}$ is to the stored training points. This method is called kernel
regression, kernel smoothing, or the Nadaraya-Watson (N-W) model. See Figure 16.10 for an
example, where we use a Gaussian kernel.

In Section 17.2.3, we discuss the connection between kernel regression and Gaussian process
regression.
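The N-W estimator can be sketched directly from its definition as a normalized, kernel-weighted average of the training responses. The code below is a minimal 1d illustration with a Gaussian kernel; the function names, bandwidths and toy data are our own, and no safeguard is included for queries so far from the data that all kernel values underflow to zero.

```python
import math

def gaussian_kernel(u, h):
    """Gaussian density kernel with bandwidth h, evaluated at u."""
    return math.exp(-0.5 * (u / h) ** 2) / (h * math.sqrt(2.0 * math.pi))

def nw_weights(x, xs, h):
    """Normalized weights w_n(x) = K_h(x - x_n) / sum_n' K_h(x - x_n')."""
    k = [gaussian_kernel(x - xn, h) for xn in xs]
    total = sum(k)
    return [kn / total for kn in k]

def nw_predict(x, xs, ys, h):
    """Nadaraya-Watson estimate of E[y|x]: a weighted sum of the training outputs."""
    return sum(w * y for w, y in zip(nw_weights(x, xs, h), ys))

xs = [0.0, 1.0, 2.0, 3.0]
ys = [0.0, 1.0, 4.0, 9.0]

print(nw_weights(1.5, xs, h=0.5))        # symmetric about the query point
print(nw_predict(1.5, xs, ys, h=0.5))    # lies between min(ys) and max(ys)
print(nw_predict(2.0, xs, ys, h=0.1))    # a small bandwidth nearly reproduces y = 4
```

Because the weights are non-negative and sum to one, the prediction is always a convex combination of the observed responses, so the estimator cannot extrapolate beyond their range.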


16.3.5.2 Estimator for the variance 


Sometimes it is useful to compute the predictive variance, as well as the predictive mean. We can do 
this by noting that 


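$$\mathbb{V}[y|\mathbf{x}, \mathcal{D}] = \mathbb{E}[y^2|\mathbf{x}, \mathcal{D}] - \mu(\mathbf{x})^2 \tag{16.41}$$

where $\mu(\mathbf{x}) = \mathbb{E}[y|\mathbf{x}, \mathcal{D}]$ is the N-W estimate. If we use a Gaussian kernel with variance $\sigma^2$, we can
compute $\mathbb{E}[y^2|\mathbf{x}, \mathcal{D}]$ as follows:

$$\mathbb{E}[y^2|\mathbf{x}, \mathcal{D}] = \frac{\sum_{n=1}^{N} \mathcal{K}_h(\mathbf{x} - \mathbf{x}_n) \int y^2\, \mathcal{K}_h(y - y_n)\,dy}{\sum_{n'=1}^{N} \mathcal{K}_h(\mathbf{x} - \mathbf{x}_{n'}) \int \mathcal{K}_h(y - y_{n'})\,dy} \tag{16.42}$$

$$= \frac{\sum_{n=1}^{N} \mathcal{K}_h(\mathbf{x} - \mathbf{x}_n)\,(\sigma^2 + y_n^2)}{\sum_{n'=1}^{N} \mathcal{K}_h(\mathbf{x} - \mathbf{x}_{n'})} \tag{16.43}$$

where we used the fact that

$$\int y^2\, \mathcal{N}(y|y_n, \sigma^2)\,dy = \sigma^2 + y_n^2 \tag{16.44}$$

Combining Equation (16.43) with Equation (16.41) gives

$$\mathbb{V}[y|\mathbf{x}, \mathcal{D}] = \sigma^2 + \sum_{n=1}^{N} w_n(\mathbf{x})\, y_n^2 - \mu(\mathbf{x})^2 \tag{16.45}$$

This matches Eqn. 8 of [BA10] (modulo the initial $\sigma^2$ term).
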
16.3.5.3 Locally weighted regression 
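We can drop the normalization term from Equation (16.39) to get

$$\mu(\mathbf{x}) = \sum_{n=1}^{N} y_n\, \mathcal{K}_h(\mathbf{x} - \mathbf{x}_n) \tag{16.46}$$

This is just a weighted sum of the observed responses, where the weights depend on how similar the
test input $\mathbf{x}$ is to the training points $\mathbf{x}_n$.

Rather than just interpolating the stored responses $y_n$, we can fit a locally linear model around
each training point:

$$\mu(\mathbf{x}) = \min_{\boldsymbol{\beta}} \sum_{n=1}^{N} [y_n - \boldsymbol{\beta}^{\mathsf{T}} \boldsymbol{\phi}(\mathbf{x}_n)]^2\, \mathcal{K}_h(\mathbf{x} - \mathbf{x}_n) \tag{16.47}$$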


where $\boldsymbol{\phi}(\mathbf{x}) = [1, \mathbf{x}]$. This is called locally linear regression (LLR) or locally-weighted scatterplot
smoothing, and is commonly known by the acronym LOWESS or LOESS [CD88]. This
is often used when annotating scatter plots with local trend lines.

In this chapter, we consider nonparametric methods for regression and classification. Such
methods do not assume a fixed parametric form for the prediction function, but instead try to
estimate the function itself (rather than the parameters) directly from data. The key idea is that we
observe the function value at a fixed set of $N$ points, namely $y_n = f(\mathbf{x}_n)$ for $n = 1:N$, where $f$
is the unknown function, so to predict the function value at a new point, say $\mathbf{x}_*$, we just have to
compare how “similar” $\mathbf{x}_*$ is to each of the $N$ training points, $\{\mathbf{x}_n\}$, and then we can predict that
$f(\mathbf{x}_*)$ is some weighted combination of the $\{f(\mathbf{x}_n)\}$ values. Thus we may need to “remember” the
entire training set, $\mathcal{D} = \{(\mathbf{x}_n, y_n)\}$, in order to make predictions at test time; we cannot “compress”
$\mathcal{D}$ into a fixed-sized parameter vector.

The weights that are used for prediction are determined by the similarity between $\mathbf{x}_*$ and each $\mathbf{x}_n$,
which is computed using a special kind of function known as a kernel function, $\mathcal{K}(\mathbf{x}_n, \mathbf{x}_*) \ge 0$, which
we explain in Section 17.1. This approach is similar to RBF networks (Section 13.6.1), except we use
the datapoints $\{\mathbf{x}_n\}$ themselves as the “anchors”, rather than learning the RBF centroids $\{\boldsymbol{\mu}_k\}$.

In Section 17.2, we discuss an approach called Gaussian processes, which allows us to use the kernel 
to define a prior over functions, which we can update given data to get a posterior over functions. 
Alternatively we can use the same kernel with a method called Support Vector Machines to compute 
a MAP estimate of the function, as we explain in Section 17.3. 




17.1 Mercer kernels 


The key to nonparametric methods is that we need a way to encode prior knowledge about the 
similarity of two input vectors. If we know that $\mathbf{x}_i$ is similar to $\mathbf{x}_j$, then we can encourage the model
to make the predicted output at both locations (i.e., $f(\mathbf{x}_i)$ and $f(\mathbf{x}_j)$) to be similar.

To define similarity, we introduce the notion of a kernel function. The word “kernel” has many 
different meanings in mathematics, including density kernels (Section 16.3.1), transition kernels of a 
Markov chain (Section 3.6.1.2), and convolutional kernels (Section 14.1). Here we consider a Mercer 
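kernel, also called a positive definite kernel. This is any symmetric function $\mathcal{K} : \mathcal{X} \times \mathcal{X} \to \mathbb{R}^+$
such that

$$\sum_{i=1}^{N} \sum_{j=1}^{N} \mathcal{K}(\mathbf{x}_i, \mathbf{x}_j)\, c_i c_j \ge 0 \tag{17.1}$$

for any set of $N$ (unique) points $\mathbf{x}_i \in \mathcal{X}$, and any choice of numbers $c_i \in \mathbb{R}$. (We assume $\mathcal{K}(\mathbf{x}_i, \mathbf{x}_j) > 0$,
so that we can only achieve equality in the above equation if $c_i = 0$ for all $i$.)

Another way to understand this condition is the following. Given a set of $N$ datapoints, let us
define the Gram matrix as the following $N \times N$ similarity matrix:

$$\mathbf{K} = \begin{pmatrix} \mathcal{K}(\mathbf{x}_1, \mathbf{x}_1) & \cdots & \mathcal{K}(\mathbf{x}_1, \mathbf{x}_N) \\ & \vdots & \\ \mathcal{K}(\mathbf{x}_N, \mathbf{x}_1) & \cdots & \mathcal{K}(\mathbf{x}_N, \mathbf{x}_N) \end{pmatrix} \tag{17.2}$$

We say that $\mathcal{K}$ is a Mercer kernel iff the Gram matrix is positive definite for any set of (distinct)
inputs $\{\mathbf{x}_i\}_{i=1}^{N}$.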

The most widely used kernel for real-valued inputs is the squared exponential kernel (SE), 
also called the exponentiated quadratic kernel (EQ), Gaussian kernel, or RBF kernel. It is 
defined by 


$$\mathcal{K}(\mathbf{x}, \mathbf{x}') = \exp\left(-\frac{\|\mathbf{x} - \mathbf{x}'\|^2}{2\ell^2}\right) \tag{17.3}$$


Here $\ell$ corresponds to the length scale of the kernel, i.e., the distance over which we expect differences
to matter. This is known as the bandwidth parameter. The RBF kernel measures similarity between
two vectors in $\mathbb{R}^D$ using (scaled) Euclidean distance. In Section 17.1.2, we will discuss several other
kinds of kernel. 
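To make the Mercer condition concrete, the following sketch builds the Gram matrix of the RBF kernel for a few points and tests positive definiteness by attempting a Cholesky factorization (which, in exact arithmetic, succeeds only for symmetric positive definite matrices). The function names and the example points are our own; the Cholesky routine is a textbook implementation written out so that the example needs only the standard library.

```python
import math

def rbf_kernel(x, y, length_scale=1.0):
    """Squared exponential kernel exp(-||x - y||^2 / (2 l^2))."""
    return math.exp(-math.dist(x, y) ** 2 / (2.0 * length_scale ** 2))

def gram_matrix(points, kernel):
    """N x N matrix of pairwise kernel evaluations."""
    return [[kernel(a, b) for b in points] for a in points]

def is_positive_definite(K):
    """Try to compute the Cholesky factor of K; fail as soon as a pivot is not positive."""
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = K[i][j] - sum(L[i][m] * L[j][m] for m in range(j))
            if i == j:
                if s <= 0.0:
                    return False
                L[i][i] = math.sqrt(s)
            else:
                L[i][j] = s / L[j][j]
    return True

distinct = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
repeated = [[0.0, 0.0], [0.0, 0.0]]

print(is_positive_definite(gram_matrix(distinct, rbf_kernel)))   # distinct inputs
print(is_positive_definite(gram_matrix(repeated, rbf_kernel)))   # duplicated input gives a singular matrix
```

The second case shows why the definition is stated for distinct inputs: a repeated point produces two identical rows, so the Gram matrix is only positive semi-definite.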

In Section 17.2, we show how to use kernels to define priors and posteriors over functions. The 
basic idea is this: if $\mathcal{K}(\mathbf{x}, \mathbf{x}')$ is large, meaning the inputs are similar, then we expect the output of
the function to be similar as well, so $f(\mathbf{x}) \approx f(\mathbf{x}')$. More precisely, information we learn about $f(\mathbf{x})$
will help us predict $f(\mathbf{x}')$ for all $\mathbf{x}'$ which are correlated with $\mathbf{x}$, and hence for which $\mathcal{K}(\mathbf{x}, \mathbf{x}')$ is large.

In Section 17.3, we show how to use kernels to generalize from Euclidean distance to a more general 
notion of distance, so that we can use geometric methods such as linear discriminant analysis in an 
implicit feature space instead of input space. 


17.1.1 Mercer’s theorem 


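Recall from Section 7.4 that any positive definite matrix $\mathbf{K}$ can be represented using an eigendecomposition
of the form $\mathbf{K} = \mathbf{U}^{\mathsf{T}} \boldsymbol{\Lambda} \mathbf{U}$, where $\boldsymbol{\Lambda}$ is a diagonal matrix of eigenvalues $\lambda_i > 0$, and $\mathbf{U}$ is a
matrix containing the eigenvectors. Now consider element $(i, j)$ of $\mathbf{K}$:

$$k_{ij} = (\boldsymbol{\Lambda}^{\frac{1}{2}} \mathbf{U}_{:i})^{\mathsf{T}} (\boldsymbol{\Lambda}^{\frac{1}{2}} \mathbf{U}_{:j}) \tag{17.4}$$

where $\mathbf{U}_{:i}$ is the $i$'th column of $\mathbf{U}$. If we define $\boldsymbol{\phi}(\mathbf{x}_i) = \boldsymbol{\Lambda}^{\frac{1}{2}} \mathbf{U}_{:i}$, then we can write

$$k_{ij} = \boldsymbol{\phi}(\mathbf{x}_i)^{\mathsf{T}} \boldsymbol{\phi}(\mathbf{x}_j) = \sum_{m} \phi_m(\mathbf{x}_i)\, \phi_m(\mathbf{x}_j) \tag{17.5}$$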


Thus we see that the entries in the kernel matrix can be computed by performing an inner product 
of some feature vectors that are implicitly defined by the eigenvectors of the kernel matrix. This 
idea can be generalized to apply to kernel functions, not just kernel matrices; this result is known as 
Mercer’s theorem. 

For example, consider the quadratic kernel $\mathcal{K}(\mathbf{x}, \mathbf{x}') = \langle \mathbf{x}, \mathbf{x}' \rangle^2$. In 2d, we have

$$\mathcal{K}(\mathbf{x}, \mathbf{x}') = (x_1 x_1' + x_2 x_2')^2 = x_1^2 (x_1')^2 + 2 x_1 x_2 x_1' x_2' + x_2^2 (x_2')^2 \tag{17.6}$$

We can write this as $\mathcal{K}(\mathbf{x}, \mathbf{x}') = \boldsymbol{\phi}(\mathbf{x})^{\mathsf{T}} \boldsymbol{\phi}(\mathbf{x}')$ if we define $\boldsymbol{\phi}(x_1, x_2) = [x_1^2, \sqrt{2} x_1 x_2, x_2^2] \in \mathbb{R}^3$. So we
embed the 2d inputs $\mathbf{x}$ into a 3d feature space $\boldsymbol{\phi}(\mathbf{x})$.

Figure 17.1: Function samples from a GP with an ARD kernel. (a) $\ell_1 = \ell_2 = 1$. Both dimensions contribute
to the response. (b) $\ell_1 = 1$, $\ell_2 = 5$. The second dimension is essentially ignored. Adapted from Figure 5.1 of
[RW06]. Generated by gprDemoArd.ipynb.

Now consider the RBF kernel. In this case, the corresponding feature representation is infinite 
dimensional (see Section 17.2.9.3 for details). However, by working with kernel functions, we can 
avoid having to deal with infinite dimensional vectors. 
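The equivalence between the quadratic kernel and an inner product in a 3d feature space can be confirmed numerically. The function names and test vectors below are illustrative choices.

```python
import math

def quadratic_kernel(x, y):
    """K(x, x') = (x . x')^2 for 2d inputs."""
    return (x[0] * y[0] + x[1] * y[1]) ** 2

def phi(x):
    """Explicit feature map [x1^2, sqrt(2) x1 x2, x2^2]."""
    return [x[0] ** 2, math.sqrt(2.0) * x[0] * x[1], x[1] ** 2]

x, y = [1.0, 2.0], [3.0, -1.0]

kernel_value = quadratic_kernel(x, y)
feature_value = sum(a * b for a, b in zip(phi(x), phi(y)))
print(kernel_value, feature_value)   # equal up to floating-point error
```

The kernel evaluation needs one 2d dot product, whereas the explicit route requires building the 3d features first; this gap grows quickly with the input dimension and the polynomial degree.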


17.1.2 Some popular Mercer kernels 


In the sections below, we describe some popular Mercer kernels. More details can be found at [Wil14] 
and https://www.cs.toronto.edu/~duvenaud/cookbook/. 


17.1.2.1 Stationary kernels for real-valued vectors 


For real-valued inputs, $\mathcal{X} = \mathbb{R}^D$, it is common to use stationary kernels, which are functions of
the form $\mathcal{K}(\mathbf{x}, \mathbf{x}') = \mathcal{K}(\|\mathbf{x} - \mathbf{x}'\|)$; thus the value only depends on the elementwise difference between
the inputs. The RBF kernel is a stationary kernel. We give some other examples below. 


ARD kernel 


We can generalize the RBF kernel by replacing Euclidean distance with Mahalanobis distance, as 
follows: 


$$\mathcal{K}(\mathbf{r}) = \sigma^2 \exp\left(-\frac{1}{2} \mathbf{r}^{\mathsf{T}} \boldsymbol{\Sigma}^{-1} \mathbf{r}\right) \tag{17.7}$$

where $\mathbf{r} = \mathbf{x} - \mathbf{x}'$. If $\boldsymbol{\Sigma}$ is diagonal, this can be written as

$$\mathcal{K}(\mathbf{r}; \boldsymbol{\ell}, \sigma^2) = \sigma^2 \exp\left(-\frac{1}{2} \sum_{d=1}^{D} \frac{1}{\ell_d^2} r_d^2\right) = \prod_{d=1}^{D} \mathcal{K}(r_d; \ell_d, \sigma^{2/D}) \tag{17.8}$$

where

$$\mathcal{K}(r; \ell, \tau^2) = \tau^2 \exp\left(-\frac{1}{2} \frac{1}{\ell^2} r^2\right) \tag{17.9}$$

We can interpret $\sigma^2$ as the overall variance, and $\ell_d$ as defining the characteristic length scale of
dimension $d$. If $d$ is an irrelevant input dimension, we can set $\ell_d = \infty$, so the corresponding dimension
will be ignored. This is known as automatic relevancy determination or ARD (Section 11.7.7).
Hence the corresponding kernel is called the ARD kernel. See Figure 17.1 for an illustration of
some 2d functions sampled from a GP using this prior.

Figure 17.2: Functions sampled from a GP with a Matern kernel. (a) $\nu = 5/2$. (b) $\nu = 1/2$. Generated by
gpKernelPlot.ipynb.
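The ARD kernel with a diagonal covariance, and its factorization into a product of 1d kernels, can be sketched as follows. The function names and numerical values are our own; an infinite length scale is represented with `math.inf`, which makes the corresponding dimension drop out.

```python
import math

def ard_kernel(x, y, length_scales, sigma2=1.0):
    """sigma^2 * exp(-0.5 * sum_d ((x_d - y_d) / l_d)^2)."""
    s = sum(((a - b) / l) ** 2 for a, b, l in zip(x, y, length_scales))
    return sigma2 * math.exp(-0.5 * s)

def se_1d(r, length_scale, tau2):
    """One-dimensional squared exponential kernel with variance tau^2."""
    return tau2 * math.exp(-0.5 * (r / length_scale) ** 2)

x, y = [0.0, 0.0], [1.0, 3.0]

# Second dimension is irrelevant: its length scale is infinite, so only r_1 matters.
print(ard_kernel(x, y, [1.0, math.inf]), math.exp(-0.5))

# Product form: each 1d factor uses variance sigma^(2/D).
ells, sigma2 = [1.0, 5.0], 4.0
product = 1.0
for a, b, l in zip(x, y, ells):
    product *= se_1d(a - b, l, sigma2 ** (1.0 / len(ells)))
print(ard_kernel(x, y, ells, sigma2), product)
```

In practice the length scales would be learned from data, so that dimensions with very large learned length scales are effectively pruned.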


Matern kernels 


The SE kernel gives rise to functions that are infinitely differentiable, and therefore are very smooth.
For many applications, it is better to use the Matern kernel, which gives rise to “rougher” functions,
which can better model local “wiggles” without having to make the overall length scale very small.
The Matern kernel has the following form:

$$\mathcal{K}(r; \nu, \ell) = \frac{2^{1-\nu}}{\Gamma(\nu)} \left(\frac{\sqrt{2\nu}\, r}{\ell}\right)^{\nu} K_{\nu}\left(\frac{\sqrt{2\nu}\, r}{\ell}\right) \tag{17.10}$$

where $K_{\nu}$ is a modified Bessel function and $\ell$ is the length scale. Functions sampled from this GP
are $k$-times differentiable iff $\nu > k$. As $\nu \to \infty$, this approaches the SE kernel.

Figure 17.3: Functions sampled from a GP using various stationary periodic kernels. (a) Periodic kernel.
(b) Cosine kernel. Generated by gpKernelPlot.ipynb.

For values $\nu \in \{\frac{1}{2}, \frac{3}{2}, \frac{5}{2}\}$, the function simplifies as follows:

$$\mathcal{K}(r; \tfrac{1}{2}, \ell) = \exp\left(-\frac{r}{\ell}\right) \tag{17.11}$$

$$\mathcal{K}(r; \tfrac{3}{2}, \ell) = \left(1 + \frac{\sqrt{3}\, r}{\ell}\right) \exp\left(-\frac{\sqrt{3}\, r}{\ell}\right) \tag{17.12}$$

$$\mathcal{K}(r; \tfrac{5}{2}, \ell) = \left(1 + \frac{\sqrt{5}\, r}{\ell} + \frac{5 r^2}{3 \ell^2}\right) \exp\left(-\frac{\sqrt{5}\, r}{\ell}\right) \tag{17.13}$$

The value $\nu = \frac{1}{2}$ corresponds to the Ornstein-Uhlenbeck process, which describes the velocity


of a particle undergoing Brownian motion. The corresponding function is continuous but not 
differentiable, and hence is very “jagged”. See Figure 17.2b for an illustration. 
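The three closed-form Matern kernels are straightforward to implement and compare with the SE kernel. The sketch below uses our own function names and evaluates each kernel at a few distances with unit length scale.

```python
import math

def matern12(r, ell=1.0):
    """Matern kernel with nu = 1/2 (the exponential kernel)."""
    return math.exp(-r / ell)

def matern32(r, ell=1.0):
    """Matern kernel with nu = 3/2."""
    a = math.sqrt(3.0) * r / ell
    return (1.0 + a) * math.exp(-a)

def matern52(r, ell=1.0):
    """Matern kernel with nu = 5/2; note that a^2 / 3 equals 5 r^2 / (3 l^2)."""
    a = math.sqrt(5.0) * r / ell
    return (1.0 + a + a * a / 3.0) * math.exp(-a)

def squared_exponential(r, ell=1.0):
    """Limit of the Matern family as nu goes to infinity."""
    return math.exp(-r * r / (2.0 * ell * ell))

for r in (0.0, 0.5, 1.0, 2.0):
    print(r, matern12(r), matern32(r), matern52(r), squared_exponential(r))
```

All four kernels equal one at $r = 0$; they differ in how quickly the correlation decays near the origin, which is what controls the roughness of the sampled functions.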


Periodic kernels 


The periodic kernel captures repeating structure, and has the form 


$$\mathcal{K}_{\text{per}}(r; \ell, p) = \exp\left(-\frac{2}{\ell^2} \sin^2\left(\pi \frac{r}{p}\right)\right) \tag{17.14}$$


where $p$ is the period. See Figure 17.3a for an illustration.
A related kernel is the cosine kernel:

$$\mathcal{K}(r; p) = \cos\left(2\pi \frac{r}{p}\right) \tag{17.15}$$


See Figure 17.3b for an illustration.

17.1.2.2 Making new kernels from old


Given two valid kernels K: (æ, x’) and K2(a, a’), we can create a new kernel using any of the following 
methods: 


K(a,x') = cK, (æ, x'), for any constant c > 0 (17.16) 
K(x, ax’) = f(£)Kı(x, x) f(x), for any function f (17.17) 
K(x, x’) = q(Kı (x, x'")) for any function polynomial q with nonneg. coef. (17.18) 
K(a, x’) = exp(K1(a, 2’)) (17.19) 
K(a, a’) = £" Aa’, for any psd matrix A (17.20) 


For example, suppose we start with the linear kernel $\mathcal{K}(\mathbf{x}, \mathbf{x}') = \mathbf{x}^\top \mathbf{x}'$. We know this is a valid
Mercer kernel, since the corresponding Gram matrix is just the (scaled) covariance matrix of the
data. From the above rules, we can see that the polynomial kernel $\mathcal{K}(\mathbf{x}, \mathbf{x}') = (\mathbf{x}^\top \mathbf{x}')^M$ is a valid
Mercer kernel. This contains all monomials of order $M$. For example, if $M = 2$ and the inputs are
2d, we have


$$(\mathbf{x}^\top \mathbf{x}')^2 = (x_1 x_1' + x_2 x_2')^2 = (x_1 x_1')^2 + (x_2 x_2')^2 + 2 (x_1 x_1')(x_2 x_2') \tag{17.21}$$


We can generalize this to contain all terms up to degree $M$ by using the kernel $\mathcal{K}(\mathbf{x}, \mathbf{x}') = (\mathbf{x}^\top \mathbf{x}' + c)^M$.
For example, if $M = 2$ and the inputs are 2d, we have


$$\begin{aligned}
(\mathbf{x}^\top \mathbf{x}' + 1)^2 &= (x_1 x_1')^2 + (x_1 x_1')(x_2 x_2') + (x_1 x_1') \\
&\quad + (x_2 x_2')(x_1 x_1') + (x_2 x_2')^2 + (x_2 x_2') \\
&\quad + (x_1 x_1') + (x_2 x_2') + 1
\end{aligned} \tag{17.22}$$
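
Collecting the terms of the expansion shows that the degree 2 polynomial kernel is an ordinary inner product in a 6-dimensional feature space. The sketch below checks this numerically; the function names `poly_kernel` and `phi_quadratic`, and the particular ordering of the features, are assumptions chosen for this illustration.

```python
import math

def poly_kernel(x, xp, c=1.0, M=2):
    """Polynomial kernel (x^T x' + c)^M for inputs given as sequences of floats."""
    return (sum(a * b for a, b in zip(x, xp)) + c) ** M

def phi_quadratic(x):
    """Explicit feature map for (x^T x' + 1)^2 with 2d inputs.

    The sqrt(2) factors account for the terms that appear twice in the expansion.
    """
    x1, x2 = x
    r2 = math.sqrt(2.0)
    return [x1 * x1, x2 * x2, r2 * x1 * x2, r2 * x1, r2 * x2, 1.0]

x, xp = (0.5, -1.2), (2.0, 0.3)
k_direct = poly_kernel(x, xp)
k_features = sum(a * b for a, b in zip(phi_quadratic(x), phi_quadratic(xp)))
```

Here `k_direct` and `k_features` agree, even though the first never forms the feature vector explicitly.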


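We can also use the above rules to establish that the Gaussian kernel is a valid kernel. To see this,
note that


$$||\mathbf{x} - \mathbf{x}'||^2 = \mathbf{x}^\top \mathbf{x} + (\mathbf{x}')^\top \mathbf{x}' - 2 \mathbf{x}^\top \mathbf{x}' \tag{17.23}$$
and hence
$$\mathcal{K}(\mathbf{x}, \mathbf{x}') = \exp(-||\mathbf{x} - \mathbf{x}'||^2 / 2\sigma^2) = \exp(-\mathbf{x}^\top \mathbf{x} / 2\sigma^2) \exp(\mathbf{x}^\top \mathbf{x}' / \sigma^2) \exp(-(\mathbf{x}')^\top \mathbf{x}' / 2\sigma^2) \tag{17.24}$$


is a valid kernel: the middle factor is the exponential of a scaled linear kernel (Equations (17.16)
and (17.19)), and the two outer factors have the form $f(\mathbf{x})$ and $f(\mathbf{x}')$ required by Equation (17.17),
with $f(\mathbf{x}) = \exp(-\mathbf{x}^\top \mathbf{x} / 2\sigma^2)$.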


17.1.2.3 Combining kernels by addition and multiplication 
We can also combine kernels using addition or multiplication:
$$\mathcal{K}(\mathbf{x}, \mathbf{x}') = \mathcal{K}_1(\mathbf{x}, \mathbf{x}') + \mathcal{K}_2(\mathbf{x}, \mathbf{x}') \tag{17.25}$$
$$\mathcal{K}(\mathbf{x}, \mathbf{x}') = \mathcal{K}_1(\mathbf{x}, \mathbf{x}') \times \mathcal{K}_2(\mathbf{x}, \mathbf{x}') \tag{17.26}$$
Multiplying two positive-definite kernels together always results in another positive definite kernel.
This is a way to get a conjunction of the individual properties of each kernel, as illustrated in
Figure 17.4.
In addition, adding two positive-definite kernels together always results in another positive definite
kernel. This is a way to get a disjunction of the individual properties of each kernel, as illustrated in
Figure 17.5.
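
These closure properties can be checked numerically on a finite set of inputs: the Gram matrices of the sum and of the (elementwise) product should have no negative eigenvalues, up to round-off. The following NumPy sketch does this for an SE and a periodic kernel; the helper names and the grid of inputs are assumptions for illustration.

```python
import numpy as np

def se_gram(x, ell=1.0):
    """Gram matrix of the SE kernel on 1d inputs x."""
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * (d / ell) ** 2)

def per_gram(x, ell=1.0, p=1.0):
    """Gram matrix of the periodic kernel (17.14) on 1d inputs x."""
    d = x[:, None] - x[None, :]
    return np.exp(-(2.0 / ell**2) * np.sin(np.pi * d / p) ** 2)

def min_eig(K):
    """Smallest eigenvalue of a symmetric matrix."""
    return float(np.linalg.eigvalsh(K).min())

x = np.linspace(0.0, 3.0, 25)
K_sum = se_gram(x) + per_gram(x)    # Equation (17.25)
K_prod = se_gram(x) * per_gram(x)   # Equation (17.26): elementwise product of Gram matrices
```

Note that the product in Equation (17.26) corresponds to the elementwise product of the two Gram matrices, not the matrix product.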
Figure 17.4: Examples of 1d structures obtained by multiplying elementary kernels. Panels, left to right:
Lin × Lin (quadratic functions), SE × Per (locally periodic), Lin × SE (increasing variation), Lin × Per
(growing amplitude). Top row shows $\mathcal{K}(x, x' = 1)$. Bottom row shows some functions sampled from
$GP(f|0, \mathcal{K})$. From Figure 2.2 of [Duv14]. Used with kind permission of David Duvenaud.


Figure 17.5: Examples of 1d structures obtained by adding elementary kernels. Panels, left to right:
Lin + Per (periodic plus trend), SE + Per (periodic plus noise), SE + Lin (linear plus variation),
SE$^{(\text{long})}$ + SE$^{(\text{short})}$ (slow & fast variation). Here SE$^{(\text{short})}$ and SE$^{(\text{long})}$
are two SE kernels with different length scales. From Figure 2.4 of [Duv14]. Used with kind permission of
David Duvenaud.


17.1.2.4 Kernels for structured inputs 


Kernels are particularly useful when the inputs are structured objects, such as strings and graphs, 
since it is often hard to “featurize” variable-sized inputs. For example, we can define a string kernel 
which compares strings in terms of the number of n-grams they have in common [Lod+02; BC17].

We can also define kernels on graphs [KJM19]. For example, the random walk kernel conceptually 
performs random walks on two graphs simultaneously, and then counts the number of paths that 
were produced by both walks. This can be computed efficiently as discussed in [Vis+10]. For more 
details on graph kernels, see [KJM19]. 

For a review of kernels on structured objects, see e.g., [Gar03]. 
Figure 17.6: A Gaussian process for 2 training points, $\mathbf{x}_1$ and $\mathbf{x}_2$, and 1 testing point, $\mathbf{x}_*$, represented as
a graphical model representing $p(\mathbf{y}, \mathbf{f}_X|\mathbf{X}) = \mathcal{N}(\mathbf{f}_X|m(\mathbf{X}), \mathcal{K}(\mathbf{X})) \prod_i p(y_i|f_i)$. The hidden nodes $f_i = f(\mathbf{x}_i)$
represent the value of the function at each of the data points. These hidden nodes are fully interconnected
by undirected edges, forming a Gaussian graphical model; the edge strengths represent the covariance terms
$\Sigma_{ij} = \mathcal{K}(\mathbf{x}_i, \mathbf{x}_j)$. If the test point $\mathbf{x}_*$ is similar to the training points $\mathbf{x}_1$ and $\mathbf{x}_2$, then the value of the hidden
function $f_*$ will be similar to $f_1$ and $f_2$, and hence the predicted output $y_*$ will be similar to the training
values $y_1$ and $y_2$.


17.2 Gaussian processes 


In this section, we discuss Gaussian processes, which are a way to define distributions over functions
of the form $f : \mathcal{X} \to \mathbb{R}$, where $\mathcal{X}$ is any domain. The key assumption is that the function values at a set
of $M > 0$ inputs, $\mathbf{f} = [f(\mathbf{x}_1), \ldots, f(\mathbf{x}_M)]$, are jointly Gaussian, with mean $\boldsymbol{\mu} = (m(\mathbf{x}_1), \ldots, m(\mathbf{x}_M))$
and covariance $\Sigma_{ij} = \mathcal{K}(\mathbf{x}_i, \mathbf{x}_j)$, where $m$ is a mean function and $\mathcal{K}$ is a positive definite (Mercer)
kernel. Since we assume this holds for any $M > 0$, this includes the case where $M = N + 1$,
containing $N$ training points $\mathbf{x}_n$ and 1 test point $\mathbf{x}_*$. Thus we can infer $f(\mathbf{x}_*)$ from knowledge of
$f(\mathbf{x}_1), \ldots, f(\mathbf{x}_N)$ by manipulating the joint Gaussian distribution $p(f(\mathbf{x}_1), \ldots, f(\mathbf{x}_N), f(\mathbf{x}_*))$, as we
explain below. We can also extend this to work with the case where we observe noisy functions of
$f(\mathbf{x})$, such as in regression or classification problems.


17.2.1 Noise-free observations 


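Suppose we observe a training set $\mathcal{D} = \{(\mathbf{x}_n, y_n) : n = 1 : N\}$, where $y_n = f(\mathbf{x}_n)$ is the noise-free
observation of the function evaluated at $\mathbf{x}_n$. If we ask the GP to predict $f(\mathbf{x})$ for a value of $\mathbf{x}$ that it
has already seen, we want the GP to return the answer $f(\mathbf{x})$ with no uncertainty. In other words, it
should act as an interpolator of the training data.

Now we consider the case of predicting the outputs for new inputs that may not be in $\mathcal{D}$. Specifically,
given a test set $\mathbf{X}_*$ of size $N_* \times D$, we want to predict the function outputs $\mathbf{f}_* = [f(\mathbf{x}_{*1}), \ldots, f(\mathbf{x}_{*N_*})]$.
By definition of the GP, the joint distribution $p(\mathbf{f}_X, \mathbf{f}_*|\mathbf{X}, \mathbf{X}_*)$ has the following form


$$\begin{pmatrix} \mathbf{f}_X \\ \mathbf{f}_* \end{pmatrix} \sim \mathcal{N}\left( \begin{pmatrix} \boldsymbol{\mu}_X \\ \boldsymbol{\mu}_* \end{pmatrix}, \begin{pmatrix} \mathbf{K}_{X,X} & \mathbf{K}_{X,*} \\ \mathbf{K}_{X,*}^\top & \mathbf{K}_{*,*} \end{pmatrix} \right) \tag{17.27}$$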
Figure 17.7: (a) Some functions sampled from a GP prior with squared exponential kernel. (b-d) Some
samples from a GP posterior, after conditioning on 1, 2, and 4 noise-free observations. The shaded area
represents $\mathbb{E}[f(\mathbf{x})] \pm 2\,\mathrm{std}[f(\mathbf{x})]$. Adapted from Figure 2.2 of [RW06]. Generated by gprDemoNoiseFree.ipynb.


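where $\boldsymbol{\mu}_X = [m(\mathbf{x}_1), \ldots, m(\mathbf{x}_N)]$, $\boldsymbol{\mu}_* = [m(\mathbf{x}_{*1}), \ldots, m(\mathbf{x}_{*N_*})]$, $\mathbf{K}_{X,X} = \mathcal{K}(\mathbf{X}, \mathbf{X})$ is $N \times N$, $\mathbf{K}_{X,*} =
\mathcal{K}(\mathbf{X}, \mathbf{X}_*)$ is $N \times N_*$, and $\mathbf{K}_{*,*} = \mathcal{K}(\mathbf{X}_*, \mathbf{X}_*)$ is $N_* \times N_*$. See Figure 17.6 for an illustration. By the
standard rules for conditioning Gaussians (Section 3.2.3), the posterior has the following form


$$p(\mathbf{f}_*|\mathbf{X}_*, \mathcal{D}) = \mathcal{N}(\mathbf{f}_*|\boldsymbol{\mu}_*, \boldsymbol{\Sigma}_*) \tag{17.28}$$
$$\boldsymbol{\mu}_* = m(\mathbf{X}_*) + \mathbf{K}_{X,*}^\top \mathbf{K}_{X,X}^{-1} (\mathbf{f}_X - m(\mathbf{X})) \tag{17.29}$$
$$\boldsymbol{\Sigma}_* = \mathbf{K}_{*,*} - \mathbf{K}_{X,*}^\top \mathbf{K}_{X,X}^{-1} \mathbf{K}_{X,*} \tag{17.30}$$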


This process is illustrated in Figure 17.7. On the left we show some samples from the prior, p(f), 
where we use an RBF kernel (Section 17.1) and a zero mean function. On the right, we show samples 
from the posterior, p(f|D). We see that the model perfectly interpolates the training data, and that 
the predictive uncertainty increases as we move further away from the observed data. 
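
The conditioning step in Equations (17.29) and (17.30) can be sketched in NumPy as follows. This is a minimal illustration assuming a zero mean function, 1d inputs, and an SE kernel; the names `se_kernel` and `gp_posterior_noise_free`, the toy data, and the small `jitter` term added to the diagonal for numerical stability are assumptions made for this example.

```python
import numpy as np

def se_kernel(A, B, ell=1.0):
    """SE kernel matrix between 1d input arrays A (N,) and B (M,)."""
    d = A[:, None] - B[None, :]
    return np.exp(-0.5 * (d / ell) ** 2)

def gp_posterior_noise_free(X, f, Xs, ell=1.0, jitter=1e-10):
    """Posterior mean and covariance at test inputs Xs, assuming m(x) = 0."""
    Kxx = se_kernel(X, X, ell) + jitter * np.eye(len(X))
    Kxs = se_kernel(X, Xs, ell)
    Kss = se_kernel(Xs, Xs, ell)
    mu = Kxs.T @ np.linalg.solve(Kxx, f)               # Equation (17.29)
    Sigma = Kss - Kxs.T @ np.linalg.solve(Kxx, Kxs)    # Equation (17.30)
    return mu, Sigma

X = np.array([-2.0, 0.0, 1.5, 3.0])      # training inputs
f = np.sin(X)                            # noise-free observations
Xs = np.concatenate([X, [10.0]])         # the training inputs plus one far-away test input
mu, Sigma = gp_posterior_noise_free(X, f, Xs)
```

At the training inputs the posterior mean reproduces `f` and the posterior variance is essentially zero, whereas at the far-away input the variance returns to the prior value of 1.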


17.2.2 Noisy observations 


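Now let us consider the case where what we observe is a noisy version of the underlying function,
$y_n = f(\mathbf{x}_n) + \epsilon_n$, where $\epsilon_n \sim \mathcal{N}(0, \sigma_y^2)$. In this case, the model is not required to interpolate the data,
but it must come “close” to the observed data. The covariance of the observed noisy responses is
$$\mathrm{Cov}[y_i, y_j] = \mathrm{Cov}[f_i, f_j] + \mathrm{Cov}[\epsilon_i, \epsilon_j] = \mathcal{K}(\mathbf{x}_i, \mathbf{x}_j) + \sigma_y^2 \delta_{ij} \tag{17.31}$$
where $\delta_{ij} = \mathbb{I}(i = j)$. In other words
$$\mathrm{Cov}[\mathbf{y}|\mathbf{X}] = \mathbf{K}_{X,X} + \sigma_y^2 \mathbf{I}_N \triangleq \mathbf{K}_\sigma \tag{17.32}$$


The joint density of the observed data and the latent, noise-free function on the test points is given
by


$$\begin{pmatrix} \mathbf{y} \\ \mathbf{f}_* \end{pmatrix} \sim \mathcal{N}\left( \begin{pmatrix} \boldsymbol{\mu}_X \\ \boldsymbol{\mu}_* \end{pmatrix}, \begin{pmatrix} \mathbf{K}_\sigma & \mathbf{K}_{X,*} \\ \mathbf{K}_{X,*}^\top & \mathbf{K}_{*,*} \end{pmatrix} \right) \tag{17.33}$$


Hence the posterior predictive density at a set of test points $\mathbf{X}_*$ is


$$p(\mathbf{f}_*|\mathcal{D}, \mathbf{X}_*) = \mathcal{N}(\mathbf{f}_*|\boldsymbol{\mu}_{*|X}, \boldsymbol{\Sigma}_{*|X}) \tag{17.34}$$
$$\boldsymbol{\mu}_{*|X} = \boldsymbol{\mu}_* + \mathbf{K}_{X,*}^\top \mathbf{K}_\sigma^{-1} (\mathbf{y} - \boldsymbol{\mu}_X) \tag{17.35}$$
$$\boldsymbol{\Sigma}_{*|X} = \mathbf{K}_{*,*} - \mathbf{K}_{X,*}^\top \mathbf{K}_\sigma^{-1} \mathbf{K}_{X,*} \tag{17.36}$$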


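In the case of a single test input, this simplifies as follows
$$p(f_*|\mathcal{D}, \mathbf{x}_*) = \mathcal{N}(f_*|m_* + \mathbf{k}_*^\top \mathbf{K}_\sigma^{-1} (\mathbf{y} - \boldsymbol{\mu}_X),\; k_{**} - \mathbf{k}_*^\top \mathbf{K}_\sigma^{-1} \mathbf{k}_*) \tag{17.37}$$


where $\mathbf{k}_* = [\mathcal{K}(\mathbf{x}_*, \mathbf{x}_1), \ldots, \mathcal{K}(\mathbf{x}_*, \mathbf{x}_N)]$ and $k_{**} = \mathcal{K}(\mathbf{x}_*, \mathbf{x}_*)$. If the mean function is zero, we can
write the posterior mean as follows:


$$\mu_{*|X} = \mathbf{k}_*^\top (\mathbf{K}_\sigma^{-1} \mathbf{y}) \triangleq \mathbf{k}_*^\top \boldsymbol{\alpha} = \sum_{n=1}^N \mathcal{K}(\mathbf{x}_*, \mathbf{x}_n) \alpha_n \tag{17.38}$$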


This is identical to the predictions from kernel ridge regression in Equation (17.108). 


17.2.3 Comparison to kernel regression 


In Section 16.3.5, we discussed kernel regression, which is a generative approach to regression in
which we approximate $p(y, \mathbf{x})$ using kernel density estimation. In particular, Equation (16.39) gives
us


$$\mathbb{E}[y|\mathbf{x}, \mathcal{D}] = \frac{\sum_{n=1}^N \mathcal{K}_h(\mathbf{x} - \mathbf{x}_n)\, y_n}{\sum_{n'=1}^N \mathcal{K}_h(\mathbf{x} - \mathbf{x}_{n'})} = \sum_{n=1}^N y_n w_n(\mathbf{x}) \tag{17.39}$$
$$w_n(\mathbf{x}) \triangleq \frac{\mathcal{K}_h(\mathbf{x} - \mathbf{x}_n)}{\sum_{n'=1}^N \mathcal{K}_h(\mathbf{x} - \mathbf{x}_{n'})} \tag{17.40}$$


This is very similar to Equation (17.38). However, there are a few important differences. Firstly, in
a GP, we use a positive definite (Mercer) kernel instead of a density kernel; Mercer kernels can be
defined on structured objects, such as strings and graphs, which is harder to do for density kernels.
Second, a GP is an interpolator (at least when the noise variance $\sigma_y^2 = 0$), so $\mathbb{E}[y|\mathbf{x}_n, \mathcal{D}] = y_n$. By contrast, kernel
regression is not an interpolator (although it can be made into one by iteratively fitting the residuals,
as in [KJ16]). Third, a GP is a Bayesian method, which means we can estimate hyperparameters
(of the kernel) by maximizing the marginal likelihood; by contrast, in kernel regression we must use
cross-validation to estimate the kernel parameters, such as the bandwidth. Fourth, computing the
weights $w_n$ for kernel regression takes $O(N)$ time, where $N = |\mathcal{D}|$, whereas computing the weights
$\alpha_n$ for GP regression takes $O(N^3)$ time (although there are approximation methods that can reduce
this to $O(NM^2)$, as we discuss in Section 17.2.9).


17.2.4 Weight space vs function space 


In this section, we show how Bayesian linear regression is a special case of a GP. 

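Consider the linear regression model $y = f(\mathbf{x}) + \epsilon$, where $f(\mathbf{x}) = \mathbf{w}^\top \boldsymbol{\phi}(\mathbf{x})$ and $\epsilon \sim \mathcal{N}(0, \sigma_y^2)$. If we
use a Gaussian prior $p(\mathbf{w}) = \mathcal{N}(\mathbf{w}|\mathbf{0}, \boldsymbol{\Sigma}_w)$, then the posterior is as follows (see Section 11.7.2 for the
derivation):


$$p(\mathbf{w}|\mathcal{D}) = \mathcal{N}\left(\mathbf{w} \,\Big|\, \frac{1}{\sigma_y^2} \mathbf{A}^{-1} \boldsymbol{\Phi}^\top \mathbf{y},\; \mathbf{A}^{-1}\right) \tag{17.41}$$


where $\boldsymbol{\Phi}$ is the $N \times D$ design matrix, and
$$\mathbf{A} = \sigma_y^{-2} \boldsymbol{\Phi}^\top \boldsymbol{\Phi} + \boldsymbol{\Sigma}_w^{-1} \tag{17.42}$$


The posterior predictive distribution for $f_* = f(\mathbf{x}_*)$ is therefore


$$p(f_*|\mathcal{D}, \mathbf{x}_*) = \mathcal{N}\left(f_* \,\Big|\, \frac{1}{\sigma_y^2} \boldsymbol{\phi}_*^\top \mathbf{A}^{-1} \boldsymbol{\Phi}^\top \mathbf{y},\; \boldsymbol{\phi}_*^\top \mathbf{A}^{-1} \boldsymbol{\phi}_*\right) \tag{17.43}$$


where $\boldsymbol{\phi}_* = \boldsymbol{\phi}(\mathbf{x}_*)$. This views the problem of inference and prediction in weight space.
We now show that this is equivalent to the predictions made by a GP using a kernel of the form
$\mathcal{K}(\mathbf{x}, \mathbf{x}') = \boldsymbol{\phi}(\mathbf{x})^\top \boldsymbol{\Sigma}_w \boldsymbol{\phi}(\mathbf{x}')$. To see this, let $\mathbf{K} = \boldsymbol{\Phi} \boldsymbol{\Sigma}_w \boldsymbol{\Phi}^\top$, $\mathbf{k}_* = \boldsymbol{\Phi} \boldsymbol{\Sigma}_w \boldsymbol{\phi}_*$, and $k_{**} = \boldsymbol{\phi}_*^\top \boldsymbol{\Sigma}_w \boldsymbol{\phi}_*$. Using
this notation, and the matrix inversion lemma, we can rewrite Equation (17.43) as follows


$$p(f_*|\mathcal{D}, \mathbf{x}_*) = \mathcal{N}(f_*|\mu_{*|X}, \Sigma_{*|X}) \tag{17.44}$$
$$\mu_{*|X} = \boldsymbol{\phi}_*^\top \boldsymbol{\Sigma}_w \boldsymbol{\Phi}^\top (\mathbf{K} + \sigma_y^2 \mathbf{I})^{-1} \mathbf{y} = \mathbf{k}_*^\top \mathbf{K}_\sigma^{-1} \mathbf{y} \tag{17.45}$$
$$\Sigma_{*|X} = \boldsymbol{\phi}_*^\top \boldsymbol{\Sigma}_w \boldsymbol{\phi}_* - \boldsymbol{\phi}_*^\top \boldsymbol{\Sigma}_w \boldsymbol{\Phi}^\top (\mathbf{K} + \sigma_y^2 \mathbf{I})^{-1} \boldsymbol{\Phi} \boldsymbol{\Sigma}_w \boldsymbol{\phi}_* = k_{**} - \mathbf{k}_*^\top \mathbf{K}_\sigma^{-1} \mathbf{k}_* \tag{17.46}$$


which matches the results in Equation (17.37), assuming $m(\mathbf{x}) = 0$. (Non-zero mean can be captured
by adding a constant feature with value 1 to $\boldsymbol{\phi}(\mathbf{x})$.)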

Thus we can derive a GP from Bayesian linear regression. Note, however, that linear regression
assumes $\boldsymbol{\phi}(\mathbf{x})$ is a finite length vector, whereas a GP allows us to work directly in terms of kernels,
which may correspond to infinite length feature vectors (see Section 17.1.1). That is, a GP works in
function space.
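
The equivalence between the weight space and function space views can be verified numerically. The sketch below draws a random design matrix and compares the predictive mean and variance computed both ways; all variable names, sizes, and the random data are assumptions made for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 6, 3
Phi = rng.standard_normal((N, D))     # design matrix, one row of features per training point
phi_s = rng.standard_normal(D)        # feature vector of the test point
y = rng.standard_normal(N)
Sigma_w = np.diag([1.0, 0.5, 2.0])    # prior covariance of the weights
sigma_y = 0.3                         # observation noise standard deviation

# Weight space view: Equations (17.42) and (17.43).
A = Phi.T @ Phi / sigma_y**2 + np.linalg.inv(Sigma_w)
mean_w = phi_s @ np.linalg.solve(A, Phi.T @ y) / sigma_y**2
var_w = phi_s @ np.linalg.solve(A, phi_s)

# Function space view: Equations (17.45) and (17.46).
K_sigma = Phi @ Sigma_w @ Phi.T + sigma_y**2 * np.eye(N)
k_s = Phi @ Sigma_w @ phi_s
k_ss = phi_s @ Sigma_w @ phi_s
mean_f = k_s @ np.linalg.solve(K_sigma, y)
var_f = k_ss - k_s @ np.linalg.solve(K_sigma, k_s)
```

The weight space computation solves a $D \times D$ system, whereas the function space computation solves an $N \times N$ system, so which view is cheaper depends on whether there are more features or more data points.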


17.2.5 Numerical issues 


In this section, we discuss computational and numerical issues which arise when implementing the 
above equations. For notational simplicity, we assume the prior mean is zero, m(x) = 0. 
Figure 17.8: Some 1d GPs with SE kernels but different hyper-parameters fit to 20 noisy observations. The
hyper-parameters $(\ell, \sigma_f, \sigma_y)$ are as follows: (a) (1, 1, 0.1) (b) (3.0, 1.16, 0.89). Adapted from Figure 2.5 of
[RW06]. Generated by gprDemoChangeHparams.ipynb.


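The posterior predictive mean is given by $\mu_* = \mathbf{k}_*^\top \mathbf{K}_\sigma^{-1} \mathbf{y}$. For reasons of numerical stability, it is
unwise to directly invert $\mathbf{K}_\sigma$. A more robust alternative is to compute a Cholesky decomposition,
$\mathbf{K}_\sigma = \mathbf{L} \mathbf{L}^\top$, which takes $O(N^3)$ time. Then we compute $\boldsymbol{\alpha} = \mathbf{L}^\top \backslash (\mathbf{L} \backslash \mathbf{y})$, where we have used the
backslash operator to represent backsubstitution (Section 7.7.1). Given this, we can compute the
posterior mean for each test case in $O(N)$ time using


$$\mu_* = \mathbf{k}_*^\top \mathbf{K}_\sigma^{-1} \mathbf{y} = \mathbf{k}_*^\top \mathbf{L}^{-\top} (\mathbf{L}^{-1} \mathbf{y}) = \mathbf{k}_*^\top \boldsymbol{\alpha} \tag{17.47}$$
We can compute the variance in $O(N^2)$ time for each test case using
$$\sigma_*^2 = k_{**} - \mathbf{k}_*^\top \mathbf{L}^{-\top} \mathbf{L}^{-1} \mathbf{k}_* = k_{**} - \mathbf{v}^\top \mathbf{v} \tag{17.48}$$


where $\mathbf{v} = \mathbf{L} \backslash \mathbf{k}_*$.
Finally, the log marginal likelihood (needed for kernel learning, Section 17.2.6) can be computed
using


$$\log p(\mathbf{y}|\mathbf{X}) = -\frac{1}{2} \mathbf{y}^\top \boldsymbol{\alpha} - \sum_{n=1}^N \log L_{nn} - \frac{N}{2} \log(2\pi) \tag{17.49}$$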
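
The three quantities above can be computed from a single Cholesky factor, as in the following NumPy sketch. It assumes a zero prior mean and a single test input; the function names and the toy data are assumptions made for this example. For simplicity it calls the generic `np.linalg.solve`; a dedicated triangular solver (for example `scipy.linalg.solve_triangular`) would exploit the structure of $\mathbf{L}$.

```python
import numpy as np

def se(a, b, ell=1.0):
    """SE kernel matrix between 1d input arrays a and b."""
    return np.exp(-0.5 * ((a[:, None] - b[None, :]) / ell) ** 2)

def gp_predict_cholesky(K_sigma, k_star, k_ss, y):
    """Predictive mean, predictive variance and log marginal likelihood via Cholesky."""
    L = np.linalg.cholesky(K_sigma)                        # K_sigma = L L^T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))    # alpha = L^T \ (L \ y)
    mu = k_star @ alpha                                    # Equation (17.47)
    v = np.linalg.solve(L, k_star)
    var = k_ss - v @ v                                     # Equation (17.48)
    N = len(y)
    log_ml = (-0.5 * y @ alpha
              - np.sum(np.log(np.diag(L)))
              - 0.5 * N * np.log(2 * np.pi))               # Equation (17.49)
    return mu, var, log_ml

rng = np.random.default_rng(0)
X = np.linspace(-3.0, 3.0, 8)
y = np.sin(X) + 0.1 * rng.standard_normal(8)
K_sigma = se(X, X) + 0.1**2 * np.eye(8)
k_star = se(X, np.array([0.5]))[:, 0]
mu, var, log_ml = gp_predict_cholesky(K_sigma, k_star, 1.0, y)
```

The sum of the log diagonal entries of $\mathbf{L}$ equals $\frac{1}{2} \log|\mathbf{K}_\sigma|$, which is why no separate determinant computation is needed.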


17.2.6 Estimating the kernel 


Most kernels have some free parameters, which can have a large effect on the predictions from the
model. For example, suppose we are performing 1d regression using a GP with an RBF kernel of the
form


$$\mathcal{K}(x_p, x_q) = \sigma_f^2 \exp\left(-\frac{1}{2\ell^2} (x_p - x_q)^2\right) \tag{17.50}$$


Here $\ell$ is the horizontal scale over which the function changes, and $\sigma_f^2$ controls the vertical scale of the
function. We assume observation noise with variance $\sigma_y^2$.
We sampled 20 observations from an MVN with a covariance given by $\boldsymbol{\Sigma} = \mathcal{K}(x_i, x_j)$ for a grid of
points $\{x_i\}$, and added observation noise of value $\sigma_y$. We then fit this data using a GP with the
same kernel, but with a range of hyperparameters. Figure 17.8 illustrates the effects of changing
these parameters. In Figure 17.8(a), we use $(\ell, \sigma_f, \sigma_y) = (1, 1, 0.1)$, and the result is a good fit. In
Figure 17.8(b), we increase the length scale to $\ell = 3$; now the function looks overly smooth.


17.2.6.1 Empirical Bayes 


To estimate the kernel parameters $\boldsymbol{\theta}$ (sometimes called hyperparameters), we could use exhaustive
search over a discrete grid of values, with validation loss as an objective, but this can be quite
slow. (This is the approach used by nonprobabilistic methods, such as SVMs (Section 17.3) to tune
kernels.) Here we consider an empirical Bayes approach (Section 4.6.5.3), which will allow us to use
gradient-based optimization methods, which are much faster. In particular, we will maximize the
marginal likelihood


$$p(\mathbf{y}|\mathbf{X}, \boldsymbol{\theta}) = \int p(\mathbf{y}|\mathbf{f}, \mathbf{X})\, p(\mathbf{f}|\mathbf{X}, \boldsymbol{\theta})\, d\mathbf{f} \tag{17.51}$$


(The reason it is called the marginal likelihood, rather than just likelihood, is because we have 
marginalized out the latent Gaussian vector f.) 

For notational simplicity, we assume the mean function is 0. Since $p(\mathbf{f}|\mathbf{X}) = \mathcal{N}(\mathbf{f}|\mathbf{0}, \mathbf{K})$, and
$p(\mathbf{y}|\mathbf{f}) = \prod_{n=1}^N \mathcal{N}(y_n|f_n, \sigma_y^2)$, the marginal likelihood is given by


$$\log p(\mathbf{y}|\mathbf{X}, \boldsymbol{\theta}) = \log \mathcal{N}(\mathbf{y}|\mathbf{0}, \mathbf{K}_\sigma) = -\frac{1}{2} \mathbf{y}^\top \mathbf{K}_\sigma^{-1} \mathbf{y} - \frac{1}{2} \log|\mathbf{K}_\sigma| - \frac{N}{2} \log(2\pi) \tag{17.52}$$


where the dependence of $\mathbf{K}_\sigma = \mathbf{K}_{X,X} + \sigma_y^2 \mathbf{I}_N$ on $\boldsymbol{\theta}$ is implicit. The first term is a data fit term, the
second term is a model complexity term, and the third term is just a constant. To understand the
tradeoff between the first two terms, consider a SE kernel in 1D, as we vary the length scale $\ell$ and
hold $\sigma_y^2$ fixed. For short length scales, the fit will be good, so $\mathbf{y}^\top \mathbf{K}_\sigma^{-1} \mathbf{y}$ will be small. However, the
model complexity will be high: $\mathbf{K}$ will be almost diagonal (as in Figure 13.22, top right), since most
points will not be considered “near” any others, so the $\log|\mathbf{K}_\sigma|$ term will be large. For long length
scales, the fit will be poor but the model complexity will be low: $\mathbf{K}$ will be almost all 1’s (as in
Figure 13.22, bottom right), so $\log|\mathbf{K}_\sigma|$ will be small.
We now discuss how to maximize the marginal likelihood. One can show that 


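$$\frac{\partial}{\partial \theta_j} \log p(\mathbf{y}|\mathbf{X}, \boldsymbol{\theta}) = \frac{1}{2} \mathbf{y}^\top \mathbf{K}_\sigma^{-1} \frac{\partial \mathbf{K}_\sigma}{\partial \theta_j} \mathbf{K}_\sigma^{-1} \mathbf{y} - \frac{1}{2} \mathrm{tr}\left(\mathbf{K}_\sigma^{-1} \frac{\partial \mathbf{K}_\sigma}{\partial \theta_j}\right) \tag{17.53}$$
$$= \frac{1}{2} \mathrm{tr}\left((\boldsymbol{\alpha} \boldsymbol{\alpha}^\top - \mathbf{K}_\sigma^{-1}) \frac{\partial \mathbf{K}_\sigma}{\partial \theta_j}\right) \tag{17.54}$$


where $\boldsymbol{\alpha} = \mathbf{K}_\sigma^{-1} \mathbf{y}$. It takes $O(N^3)$ time to compute $\mathbf{K}_\sigma^{-1}$, and then $O(N^2)$ time per hyper-parameter
to compute the gradient.
The form of $\frac{\partial \mathbf{K}_\sigma}{\partial \theta_j}$ depends on the form of the kernel, and which parameter we are taking derivatives
with respect to. Often we have constraints on the hyper-parameters, such as $\sigma_y^2 \geq 0$. In this case, we
can define $\theta = \log(\sigma_y^2)$, and then use the chain rule.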
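
As a concrete instance of Equation (17.54), the sketch below computes the derivative of the log marginal likelihood with respect to $\theta = \log \ell$ for a 1d SE kernel with $\sigma_f = 1$, so that the constraint $\ell > 0$ is handled by the reparameterization. The function names, the fixed noise level, and the toy data are assumptions made for this example; the explicit inverse is used only because the problem is tiny.

```python
import numpy as np

def se_gram(X, ell):
    """SE Gram matrix (sigma_f = 1) and the matrix of squared distances."""
    d2 = (X[:, None] - X[None, :]) ** 2
    return np.exp(-0.5 * d2 / ell**2), d2

def log_marglik_and_grad(log_ell, X, y, sigma_y=0.1):
    """Log marginal likelihood (17.52) and its derivative w.r.t. theta = log(ell)."""
    ell = np.exp(log_ell)
    K, d2 = se_gram(X, ell)
    K_sigma = K + sigma_y**2 * np.eye(len(X))
    L = np.linalg.cholesky(K_sigma)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    lml = -0.5 * y @ alpha - np.log(np.diag(L)).sum() - 0.5 * len(X) * np.log(2 * np.pi)
    # dK/d(ell) = K * d2 / ell^3, so by the chain rule dK/d(log ell) = K * d2 / ell^2.
    dK = K * d2 / ell**2
    K_inv = np.linalg.inv(K_sigma)   # acceptable for a small sketch
    grad = 0.5 * np.trace((np.outer(alpha, alpha) - K_inv) @ dK)   # Equation (17.54)
    return lml, grad

X = np.linspace(0.0, 5.0, 10)
y = np.sin(X)
lml, grad = log_marglik_and_grad(0.2, X, y)
```

A pair `(lml, grad)` in this form (negated, if the optimizer minimizes) can be passed to any standard gradient-based optimizer.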
Figure 17.9: Illustration of local optima in the marginal likelihood surface. (a) We plot the log marginal
likelihood vs kernel length scale $\ell$ and observation noise $\sigma_y$, for fixed signal level $\sigma_f = 1$, using the 7
data points shown in panels b and c. (b) The function corresponding to the lower left local optimum,
$(\ell, \sigma_y) \approx (1, 0.2)$. This is quite “wiggly” and has low noise. (c) The function corresponding to the top right
local optimum, $(\ell, \sigma_y) \approx (10, 0.8)$. This is quite smooth and has high noise. The data was generated using
$(\ell, \sigma_f, \sigma_y) = (1, 1, 0.1)$. Adapted from Figure 5.5 of [RW06]. Generated by gpr_demo_marglik.ipynb.


Given an expression for the log marginal likelihood and its derivative, we can estimate the kernel 
parameters using any standard gradient-based optimizer. However, since the objective is not convex, 
local minima can be a problem, as we illustrate below, so we may need to use multiple restarts. 

As an example, consider the RBF in Equation (17.50) with $\sigma_f^2 = 1$. In Figure 17.9(a), we plot
$\log p(\mathbf{y}|\mathbf{X}, \ell, \sigma_y^2)$ (where $\mathbf{X}$ and $\mathbf{y}$ are the 7 data points shown in panels b and c) as we vary $\ell$ and
$\sigma_y^2$. The two local optima are indicated by +. The bottom left optimum corresponds to a low-noise,
short-length scale solution (shown in panel b). The top right optimum corresponds to a high-noise,
long-length scale solution (shown in panel c). With only 7 data points, there is not enough evidence
to confidently decide which is more reasonable, although the more complex model (panel b) has a
marginal likelihood that is about 60% higher than the simpler model (panel c). With more data, the
more complex model would become even more preferred.

Figure 17.9 illustrates some other interesting (and typical) features. The region where $\sigma_y^2 \approx 1$
(top of panel a) corresponds to the case where the noise is very high; in this regime, the marginal
likelihood is insensitive to the length scale (indicated by the horizontal contours), since all the data
is explained as noise. The region where $\ell \approx 0.5$ (left hand side of panel a) corresponds to the case
where the length scale is very short; in this regime, the marginal likelihood is insensitive to the noise
level (indicated by the vertical contours), since the data is perfectly interpolated. Neither of these
regions would be chosen by a good optimizer.


17.2.6.2 Bayesian inference 


When we have a small number of datapoints (e.g., when using GPs for Bayesian optimization), using 
a point estimate of the kernel parameters can give poor results [Bul11; WF14]. In such cases, we may
wish to approximate the posterior over the kernel parameters. Several methods can be used. For 
example, [MA10] shows how to use slice sampling, [Hen+15] shows how to use Hamiltonian Monte 
Carlo, and [BBV11] shows how to use sequential Monte Carlo. 


Figure 17.10: GP classifier for a binary classification problem on Iris flowers (Setosa vs Versicolor) using 
a single input feature (sepal length). The fat vertical line is the credible interval for the decision boundary. 
(a) SE kernel. (b) SE plus linear kernel. Adapted from Figures 7.11-7.12 of [Mar18]. Generated by 
gp_classify_iris_1d_pymc3.ipynb. 


17.2.7 GPs for classification 


So far, we have focused on GPs for regression using Gaussian likelihoods. In this case, the posterior 
is also a GP, and all computation can be performed analytically. However, if the likelihood is 
non-Gaussian, such as the Bernoulli likelihood for binary classification, we can no longer compute 
the posterior exactly. 

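There are various approximations we can make, some of which we discuss in the sequel to this
book, [Mur23]. In this section, we use the Hamiltonian Monte Carlo method (Section 4.6.8.4), both
for the latent Gaussian function $\mathbf{f}$ as well as the kernel hyperparameters $\boldsymbol{\theta}$. The basic idea is to
specify the negative log joint

$$-\mathcal{E}(\mathbf{f}, \boldsymbol{\theta}) = \log p(\mathbf{f}, \boldsymbol{\theta}|\mathbf{X}, \mathbf{y}) = \log \mathcal{N}(\mathbf{f}|\mathbf{0}, \mathcal{K}(\mathbf{X}, \mathbf{X})) + \sum_{n=1}^N \log \mathrm{Ber}(y_n|f_n(\mathbf{x}_n)) + \log p(\boldsymbol{\theta}) \tag{17.55}$$


We then use autograd to compute $\nabla_{\mathbf{f}} \mathcal{E}(\mathbf{f}, \boldsymbol{\theta})$ and $\nabla_{\boldsymbol{\theta}} \mathcal{E}(\mathbf{f}, \boldsymbol{\theta})$, and use these gradients as inputs to a
Gaussian proposal distribution.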

Let us consider a 1d example from [Mar18]. This is similar to the Bayesian logistic regression
example from Figure 4.20, where the goal is to classify iris flowers as being Setosa or Versicolor,
$y_n \in \{0, 1\}$, given information about the sepal length, $x_n$. We will use an SE kernel with length scale
$\ell$. We put a Ga(2, 0.5) prior on $\ell$.

Figure 17.10a shows the results using the SE kernel. This is similar to the results of linear logistic
regression (see Figure 4.20), except that at the edges (away from the data), the probability curves
towards 0.5. This is because the prior mean function is $m(x) = 0$, and $\sigma(0) = 0.5$. We can eliminate
this artefact by using a more flexible kernel, which encodes the prior knowledge that we expect the
output to be monotonically increasing or decreasing in the input. We can do this using a linear
kernel,


$$\mathcal{K}(x, x') = (x - c)(x' - c) \tag{17.56}$$
Figure 17.11: (a) Fictitious “space flu” binary classification problem. (b) Fit from a GP with SE kernel.
Adapted from Figures 7.13-7.14 of [Mar18]. Generated by gp_classify_spaceflu_1d_pymc3.ipynb.


We can scale and add this to the SE kernel to get


$$\mathcal{K}(x, x') = \tau (x - c)(x' - c) + \exp\left[-\frac{(x - x')^2}{2\ell^2}\right] \tag{17.57}$$


The results are shown in Figure 17.10b, and look more reasonable.

One might wonder why we bothered to use a GP, when the results are no better than a simple 
linear logistic regression model. The reason is that the GP is much more flexible, and makes fewer a 
priori assumptions, beyond smoothness. For example, suppose the data looked like Figure 17.11a. 
In this case, a linear logistic regression model could not fit the data. We could in principle use a 
neural network, but it may not work well since we only have 60 data points. However, GPs are well 
designed to handle the small sample setting. In Figure 17.11b, we show the results of fitting a GP 
with an SE kernel to this data. The results look reasonable. 


17.2.8 Connections with deep learning 


It turns out that there are many interesting connections and similarities between GPs and deep 
neural networks. For example, one can show that a neural network with a single, infinitely wide 
layer of RBF units is equivalent to a GP with an RBF kernel. (This follows from the fact that the 
RBF kernel can be expressed as the inner product of an infinite number of features.) In fact, many 
kinds of DNNs (in the infinite limit) can be converted to an equivalent GP using a specific kind of 
kernel known as the neural tangent kernel [JGH18]. See the sequel to this book, [Mur23], for 
details. 


17.2.9 Scaling GPs to large datasets 


The main disadvantage of GPs (and other kernel methods, such as SVMs, which we discuss in
Section 17.3) is that inverting the $N \times N$ kernel matrix takes $O(N^3)$ time, making the method too
slow for big datasets. Many different approximate schemes have been proposed to speed up GPs (see
e.g., [Liu+18a] for a review). In this section, we briefly mention some of them. For more details, see
the sequel to this book, [Mur23].


17.2.9.1 Sparse (inducing-point) approximations 


A simple approach to speeding up GP inference is to use less data. A better approach is to try to
“summarize” the $N$ training points $\mathbf{X}$ into $M \ll N$ inducing points or pseudo inputs $\mathbf{Z}$. This
lets us replace $p(\mathbf{f}|\mathbf{f}_X)$ with $p(\mathbf{f}|\mathbf{f}_Z)$, where $\mathbf{f}_X = \{f(\mathbf{x}) : \mathbf{x} \in \mathbf{X}\}$ is the vector of observed function
values at the training points, and $\mathbf{f}_Z = \{f(\mathbf{x}) : \mathbf{x} \in \mathbf{Z}\}$ is the vector of estimated function values at
the inducing points. By optimizing $(\mathbf{Z}, \mathbf{f}_Z)$ we can learn to “compress” the training data $(\mathbf{X}, \mathbf{f}_X)$
into a “bottleneck” $(\mathbf{Z}, \mathbf{f}_Z)$, thus speeding up computation from $O(N^3)$ to $O(M^3)$. This is called a
sparse GP. This whole process can be made rigorous using the framework of variational inference.
For details, see the sequel to this book, [Mur23].


17.2.9.2 Exploiting parallelization and kernel matrix structure 


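It takes $O(N^3)$ time to compute the Cholesky decomposition of $\mathbf{K}_\sigma$, which is needed to solve
the linear system $\mathbf{K}_\sigma \boldsymbol{\alpha} = \mathbf{y}$ and to compute $|\mathbf{K}_{X,X}|$, where $\mathbf{K}_\sigma = \mathbf{K}_{X,X} + \sigma_y^2 \mathbf{I}_N$. An alternative to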
Cholesky decomposition is to use linear algebra methods, often called Krylov subspace methods, 
which are based just on matrix vector multiplication or MVM. These approaches are often 
much faster, since they can naturally exploit structure in the kernel matrix. Moreover, even if the 
kernel matrix does not have special structure, matrix multiplies are trivial to parallelize, and can 
thus be greatly accelerated by GPUs, unlike Cholesky based methods which are largely sequential. 
This is the basis of the popular GPyTorch package [Gar+18]. For more details, see the sequel to 
this book, [Mur23]. 


17.2.9.3 Random feature approximation 


Although the power of kernels resides in the ability to avoid working with featurized representations 
of the inputs, such kernelized methods take $O(N^3)$ time, in order to invert the Gram matrix $\mathbf{K}$.
This can make it difficult to use such methods on large scale data. Fortunately, we can approximate 
the feature map for many (shift invariant) kernels using a randomly chosen finite set of M basis 
functions, thus reducing the cost to $O(NM + M^3)$. We briefly discuss this idea below. For more
details, see e.g., [Liu+20].


Random features for RBF kernel 


We will focus on the case of the Gaussian RBF kernel. One can show that 


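$$\mathcal{K}(\mathbf{x}, \mathbf{x}') \approx \boldsymbol{\phi}(\mathbf{x})^\top \boldsymbol{\phi}(\mathbf{x}') \tag{17.58}$$

where the (real-valued) feature vector is given by
$$\boldsymbol{\phi}(\mathbf{x}) \triangleq \frac{1}{\sqrt{T}} \left[\sin(\boldsymbol{\omega}_1^\top \mathbf{x}), \ldots, \sin(\boldsymbol{\omega}_T^\top \mathbf{x}), \cos(\boldsymbol{\omega}_1^\top \mathbf{x}), \ldots, \cos(\boldsymbol{\omega}_T^\top \mathbf{x})\right] \tag{17.59}$$
$$= \frac{1}{\sqrt{T}} \left[\sin(\boldsymbol{\Omega} \mathbf{x}), \cos(\boldsymbol{\Omega} \mathbf{x})\right] \tag{17.60}$$


where $T = M/2$, and $\boldsymbol{\Omega} \in \mathbb{R}^{T \times D}$ is a random Gaussian matrix, where the entries are sampled iid
from $\mathcal{N}(0, 1/\sigma^2)$, where $\sigma$ is the kernel bandwidth. The bias of the approximation decreases as we
increase $M$. In practice, we use a finite $M$, and compute a single sample Monte Carlo approximation
to the expectation by drawing a single random matrix. The features in Equation (17.60) are called
random Fourier features (RFF) [RR08] or “weighted sums of random kitchen sinks” [RR09].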
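
The quality of this approximation is easy to check empirically. The following NumPy sketch builds the features of Equation (17.60) for a handful of random inputs and compares the resulting inner products with the exact Gaussian RBF kernel; the function name `rff_features`, the seed, and the sizes are assumptions chosen for this illustration.

```python
import numpy as np

def rff_features(X, Omega):
    """Map inputs X (N, D) to 2T random Fourier features using Omega (T, D)."""
    T = Omega.shape[0]
    proj = X @ Omega.T                       # (N, T); entry (n, t) is omega_t^T x_n
    return np.hstack([np.sin(proj), np.cos(proj)]) / np.sqrt(T)

rng = np.random.default_rng(0)
N, D, T, sigma = 5, 3, 20000, 1.5
X = rng.standard_normal((N, D))
Omega = rng.standard_normal((T, D)) / sigma  # entries sampled iid from N(0, 1/sigma^2)

Phi = rff_features(X, Omega)
K_approx = Phi @ Phi.T                       # approximate Gram matrix, Equation (17.58)

sq_dist = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
K_exact = np.exp(-sq_dist / (2 * sigma**2))  # exact Gaussian RBF Gram matrix
max_err = np.abs(K_approx - K_exact).max()
```

The error shrinks at roughly the Monte Carlo rate $1/\sqrt{T}$, so the large value of `T` used here is only for a tight visual match; much smaller values are often sufficient in practice.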

We can also use positive random features, rather than trigonometric random features, which can be 
preferable in some applications, such as models which use attention (see Section 15.6.4). In particular, 
we can use 


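$$\boldsymbol{\phi}(\mathbf{x}) \triangleq e^{-||\mathbf{x}||^2/2} \frac{1}{\sqrt{M}} \left[\exp(\boldsymbol{\omega}_1^\top \mathbf{x}), \ldots, \exp(\boldsymbol{\omega}_M^\top \mathbf{x})\right] \tag{17.61}$$


where $\boldsymbol{\omega}_m$ are sampled as before. For details, see [Cho+20b].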

Regardless of whether we use trigonometric or positive features, we can obtain a lower variance 
estimate by ensuring that the rows of Z are random but orthogonal; these are called orthogonal 
random features. Such sampling can be conducted efficiently via Gram-Schmidt orthogonalization 
of the unstructured Gaussian matrices [Yu+16], or several approximations that are even faster (see
[CRW17; Cho+19]). 


Fastfood approximation 


Unfortunately, storing the random matrix $\boldsymbol{\Omega}$ takes $O(DM)$ space, and computing $\boldsymbol{\Omega} \mathbf{x}$ takes $O(DM)$
time, where $D$ is the input dimensionality, and $M$ is the number of random features. This can be
prohibitive if $M \gg D$, which it may need to be in order to get any benefits over using the original set
of features. Fortunately, we can use the fast Hadamard transform to reduce the memory from 
O(MD) to O(M), and reduce the time from O(MD) to O(M log D). This approach has been called 
fastfood [LSS13], a reference to the original term “kitchen sinks”. 


Extreme learning machines 


We can use the random features approximation to the kernel to convert a GP into a linear model of 
the form 


$$f(\mathbf{x}; \boldsymbol{\theta}) = \mathbf{W} \boldsymbol{\phi}(\mathbf{x}) = \mathbf{W} h(\mathbf{Z} \mathbf{x}) \tag{17.62}$$


where $h(a) = \sqrt{1/M}\, [\sin(a), \cos(a)]$ for RBF kernels. This is equivalent to a one-layer MLP with
random (and fixed) input-to-hidden weights. When $M > N$, this corresponds to an over-parameterized
model, which can perfectly interpolate the training data.

In [Cur+17], they apply this method to fit a logistic regression model of the form $f(\mathbf{x}; \boldsymbol{\theta}) =
\mathbf{W}^\top h(\mathbf{Z} \mathbf{x}) + \mathbf{b}$ using SGD; they call the resulting method “McKernel”. We can also optimize $\mathbf{Z}$ as
well as $\mathbf{W}$, as discussed in [Alb+17], although now the problem is no longer convex.

Alternatively, we can use $M < N$, but stack many such random nonlinear layers together, and just
optimize the output weights. This has been called an extreme learning machine or ELM (see
e.g., [Hua14]), although this work is controversial.¹


1. The controversy has arisen because the inventor Guang-Bin Huang has been accused of not citing related prior
work, such as the equivalent approach based on random feature approximations to kernels. For details, see
https://en.wikipedia.org/wiki/Extreme_learning_machine#Controversy.


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 




Figure 17.12: Illustration of the large margin principle. Left: a separating hyper-plane with large margin. 
Right: a separating hyper-plane with small margin. 


17.3 Support vector machines (SVMs) 


In this section, we discuss a form of (non-probabilistic) predictors for classification and regression
problems which have the form


$$f(\mathbf{x}) = \sum_{i=1}^N \alpha_i \mathcal{K}(\mathbf{x}, \mathbf{x}_i) \tag{17.63}$$


By adding suitable constraints, we can ensure that many of the $\alpha_i$ coefficients are 0, so that predictions
at test time only depend on a subset of the training points, known as “support vectors”. Hence
the resulting model is called a support vector machine or SVM. We give a brief summary below.
More details can be found in e.g., [VGS97; SS01].


17.3.1 Large margin classifiers 


Consider a binary classifier of the form $h(\mathbf{x}) = \mathrm{sign}(f(\mathbf{x}))$, where the decision boundary is given by
the following linear function:


$$f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} + w_0 \tag{17.64}$$


(In the SVM literature, it is common to assume the class labels are $-1$ and $+1$, rather than 0 and 1.
To avoid confusion, we denote such target labels by $\tilde{y}$ rather than $y$.) There may be many lines that
separate the data. However, intuitively we would like to pick the one that has maximum margin,
which is the distance of the closest point to the decision boundary, since this will give us the most
robust solution. This idea is illustrated in Figure 17.12: the solution on the left has larger margin
than the one on the right, so it will be less sensitive to perturbations of the data.

How can we compute such a large margin classifier? First we need to derive an expression for
the distance of a point to the decision boundary. Referring to Figure 17.13(a), we see that


$$\mathbf{x} = \mathbf{x}_\perp + r \frac{\mathbf{w}}{||\mathbf{w}||} \tag{17.65}$$


where $r$ is the distance of $\mathbf{x}$ from the decision boundary whose normal vector is $\mathbf{w}$, and $\mathbf{x}_\perp$ is the
orthogonal projection of $\mathbf{x}$ onto this boundary.
Figure 17.13: (a) Illustration of the geometry of a linear decision boundary in 2d. A point $\mathbf{x}$ is classified as
belonging in decision region $\mathcal{R}_1$ if $f(\mathbf{x}) > 0$, otherwise it belongs in decision region $\mathcal{R}_0$; $\mathbf{w}$ is a vector which is
perpendicular to the decision boundary. The term $w_0$ controls the distance of the decision boundary from the
origin. $\mathbf{x}_\perp$ is the orthogonal projection of $\mathbf{x}$ onto the boundary. The signed distance of $\mathbf{x}$ from the boundary
is given by $f(\mathbf{x})/||\mathbf{w}||$. Adapted from Figure 4.1 of [Bis06]. (b) Points with circles around them are support
vectors, and have dual variables $\alpha_n > 0$. In the soft margin case, we associate a slack variable $\xi_n$ with each
example. If $0 < \xi_n < 1$, the point is inside the margin, but on the correct side of the decision boundary. If
$\xi_n > 1$, the point is on the wrong side of the boundary. Adapted from Figure 7.3 of [Bis06].


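We would like to maximize $r$, so we need to express it as a function of $\mathbf{w}$. First, note that


$$f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} + w_0 = (\mathbf{w}^\top \mathbf{x}_\perp + w_0) + r \frac{\mathbf{w}^\top \mathbf{w}}{||\mathbf{w}||} = (\mathbf{w}^\top \mathbf{x}_\perp + w_0) + r ||\mathbf{w}|| \tag{17.66}$$


Since $0 = f(\mathbf{x}_\perp) = \mathbf{w}^\top \mathbf{x}_\perp + w_0$, we have $f(\mathbf{x}) = r ||\mathbf{w}||$ and hence $r = \frac{f(\mathbf{x})}{||\mathbf{w}||}$.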
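
The relation $r = f(\mathbf{x})/||\mathbf{w}||$ can be checked on a small 2d example using only the standard library. The function names and the specific numbers below are assumptions chosen so that the arithmetic is easy to follow by hand.

```python
import math

def f_linear(x, w, w0):
    """Linear function f(x) = w^T x + w0 of Equation (17.64)."""
    return sum(wi * xi for wi, xi in zip(w, x)) + w0

def signed_distance(x, w, w0):
    """Signed distance r = f(x) / ||w|| of x from the decision boundary."""
    return f_linear(x, w, w0) / math.sqrt(sum(wi * wi for wi in w))

w, w0 = [3.0, 4.0], -5.0     # ||w|| = 5
x = [3.0, 4.0]               # f(x) = 9 + 16 - 5 = 20
r = signed_distance(x, w, w0)

# Orthogonal projection onto the boundary, obtained by rearranging Equation (17.65).
norm_w = 5.0
x_perp = [xi - r * wi / norm_w for xi, wi in zip(x, w)]

# Rescaling (w, w0) by a positive constant k leaves the distance unchanged.
r_scaled = signed_distance(x, [10.0 * wi for wi in w], 10.0 * w0)
```

Here `r` equals 4, the projected point `x_perp` satisfies $f(\mathbf{x}_\perp) = 0$, and `r_scaled` equals `r`, which is the scale invariance that motivates fixing the scale of $\mathbf{w}$ in the margin objective.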
Since we want to ensure each point is on the correct side of the boundary, we also require
$f(\mathbf{x}_n)\tilde{y}_n > 0$. We want to maximize the distance of the closest point, so our final objective becomes

$$
\max_{\mathbf{w}, w_0} \frac{1}{\|\mathbf{w}\|} \min_{n=1}^{N} \left[\tilde{y}_n(\mathbf{w}^{\top}\mathbf{x}_n + w_0)\right] \tag{17.67}
$$

Note that by rescaling the parameters using $\mathbf{w} \rightarrow k\mathbf{w}$ and $w_0 \rightarrow kw_0$, we do not change the distance
of any point to the boundary, since the $k$ factor cancels out when we divide by $\|\mathbf{w}\|$. Therefore let
us define the scale factor such that $\tilde{y}_n f_n = 1$ for the point that is closest to the decision boundary.

Figure 17.14: Illustration of the benefits of scaling the input features before computing a max margin classifier.
Panel titles: Unscaled, Scaled. Adapted from Figure 5.2 of [Gér19]. Generated by svm_classifier_feature_scaling.ipynb.


Hence we require $\tilde{y}_n f_n \geq 1$ for all $n$. Finally, note that maximizing $1/\|\mathbf{w}\|$ is equivalent to minimizing
$\|\mathbf{w}\|^2$. Thus we get the new objective

$$
\min_{\mathbf{w}, w_0} \frac{1}{2}\|\mathbf{w}\|^2 \quad \text{s.t.} \quad \tilde{y}_n(\mathbf{w}^{\top}\mathbf{x}_n + w_0) \geq 1, \; n = 1:N \tag{17.68}
$$

(The factor of $\frac{1}{2}$ is added for convenience and doesn't affect the optimal parameters.) The constraint
says that we want all points to be on the correct side of the decision boundary with a margin of at
least 1.
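
The geometry above can be checked numerically. The following is a minimal sketch, assuming a hypothetical toy dataset and a hand-picked separating line (neither comes from the text): it computes the margin as the smallest signed distance, confirms that rescaling the parameters leaves it unchanged, and applies the canonical scaling under which the margin equals $1/\|\mathbf{w}\|$.

```python
import math

def f(w, w0, x):
    """Linear discriminant f(x) = w^T x + w0."""
    return sum(wi * xi for wi, xi in zip(w, x)) + w0

def margin(w, w0, X, y):
    """Smallest signed distance y_n f(x_n) / ||w|| over the dataset."""
    norm = math.sqrt(sum(wi * wi for wi in w))
    return min(yn * f(w, w0, xn) / norm for xn, yn in zip(X, y))

# Hypothetical separable data with labels in {-1, +1}.
X = [(0.0, 0.0), (1.0, 0.0), (3.0, 1.0), (4.0, 0.0)]
y = [-1, -1, +1, +1]
w, w0 = (2.0, 0.0), -4.0          # decision boundary is the line x1 = 2

m = margin(w, w0, X, y)

# Rescaling (w, w0) -> (k w, k w0) does not move the boundary or change the margin.
k = 7.5
m_scaled = margin(tuple(k * wi for wi in w), k * w0, X, y)

# Canonical scaling: divide by min_n y_n f(x_n) so the closest point has y_n f_n = 1.
f_min = min(yn * f(w, w0, xn) for xn, yn in zip(X, y))
w_c, w0_c = tuple(wi / f_min for wi in w), w0 / f_min
```

After the canonical rescaling, `1 / ||w_c||` coincides with `m`, which is why minimizing $\|\mathbf{w}\|^2$ under the constraints of Equation (17.68) maximizes the margin.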

Note that it is important to scale the input variables before using an SVM, otherwise the margin 
measures distance of a point to the boundary using all input dimensions equally. See Figure 17.14 
for an illustration. 


17.3.2 The dual problem 


The objective in Equation (17.68) is a standard quadratic programming problem (Section 8.5.4),
since we have a quadratic objective subject to linear constraints. This has $D + 1$ variables ($\mathbf{w}$ and $w_0$)
subject to $N$ constraints, and is known as a primal problem.

In convex optimization, for every primal problem we can derive a dual problem. Let $\boldsymbol{\alpha} \in \mathbb{R}^N$ be
the dual variables, corresponding to Lagrange multipliers that enforce the $N$ inequality constraints.
The generalized Lagrangian is given below (see Section 8.5.2 for relevant background information on
constrained optimization):


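$$
L(\mathbf{w}, w_0, \boldsymbol{\alpha}) = \frac{1}{2}\mathbf{w}^{\top}\mathbf{w} - \sum_{n=1}^{N} \alpha_n\left(\tilde{y}_n(\mathbf{w}^{\top}\mathbf{x}_n + w_0) - 1\right) \tag{17.69}
$$

To optimize this, we must find a stationary point that satisfies

$$
(\hat{\mathbf{w}}, \hat{w}_0, \hat{\boldsymbol{\alpha}}) = \min_{\mathbf{w}, w_0} \max_{\boldsymbol{\alpha}} L(\mathbf{w}, w_0, \boldsymbol{\alpha}) \tag{17.70}
$$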
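
We can do this by computing the partial derivatives wrt $\mathbf{w}$ and $w_0$ and setting them to zero. We have

\begin{align}
\nabla_{\mathbf{w}} L(\mathbf{w}, w_0, \boldsymbol{\alpha}) &= \mathbf{w} - \sum_{n=1}^{N} \alpha_n \tilde{y}_n \mathbf{x}_n \tag{17.71} \\
\frac{\partial}{\partial w_0} L(\mathbf{w}, w_0, \boldsymbol{\alpha}) &= -\sum_{n=1}^{N} \alpha_n \tilde{y}_n \tag{17.72}
\end{align}

and hence

\begin{align}
\hat{\mathbf{w}} &= \sum_{n=1}^{N} \hat{\alpha}_n \tilde{y}_n \mathbf{x}_n \tag{17.73} \\
0 &= \sum_{n=1}^{N} \hat{\alpha}_n \tilde{y}_n \tag{17.74}
\end{align}

Plugging these into the Lagrangian yields

\begin{align}
L(\hat{\mathbf{w}}, \hat{w}_0, \boldsymbol{\alpha}) &= \frac{1}{2}\hat{\mathbf{w}}^{\top}\hat{\mathbf{w}} - \sum_{n=1}^{N} \alpha_n \tilde{y}_n \hat{\mathbf{w}}^{\top}\mathbf{x}_n - \sum_{n=1}^{N} \alpha_n \tilde{y}_n w_0 + \sum_{n=1}^{N} \alpha_n \tag{17.75} \\
&= \frac{1}{2}\hat{\mathbf{w}}^{\top}\hat{\mathbf{w}} - \hat{\mathbf{w}}^{\top}\hat{\mathbf{w}} - 0 + \sum_{n=1}^{N} \alpha_n \tag{17.76} \\
&= -\frac{1}{2}\sum_{i=1}^{N}\sum_{j=1}^{N} \alpha_i \alpha_j \tilde{y}_i \tilde{y}_j \mathbf{x}_i^{\top}\mathbf{x}_j + \sum_{n=1}^{N} \alpha_n \tag{17.77}
\end{align}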


This is called the dual form of the objective. We want to maximize this wrt $\boldsymbol{\alpha}$ subject to the
constraints that $\sum_{n=1}^{N} \alpha_n \tilde{y}_n = 0$ and $0 \leq \alpha_n$ for $n = 1:N$.

The above objective is a quadratic problem in $N$ variables. Standard QP solvers take $O(N^3)$ time.
However, specialized algorithms, which avoid the use of generic QP solvers, have been developed for
this problem, such as the sequential minimal optimization or SMO algorithm [Pla98], which
takes $O(N)$ to $O(N^2)$ time.

Since this is a convex objective, the solution must satisfy the KKT conditions (Section 8.5.2),
which tell us that the following properties hold:

\begin{align}
\alpha_n &\geq 0 \tag{17.78} \\
\tilde{y}_n f(\mathbf{x}_n) - 1 &\geq 0 \tag{17.79} \\
\alpha_n(\tilde{y}_n f(\mathbf{x}_n) - 1) &= 0 \tag{17.80}
\end{align}

Hence either $\alpha_n = 0$ (in which case example $n$ is ignored when computing $\hat{\mathbf{w}}$) or the constraint
$\tilde{y}_n(\hat{\mathbf{w}}^{\top}\mathbf{x}_n + \hat{w}_0) = 1$ is active. This latter condition means that example $n$ lies on the margin
boundary, at distance $1/\|\hat{\mathbf{w}}\|$ from the decision boundary; these points are known as the support vectors, as shown in Figure 17.13(b). We denote
the set of support vectors by $\mathcal{S}$.
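
To perform prediction, we use

$$
f(\mathbf{x}; \hat{\mathbf{w}}, \hat{w}_0) = \hat{\mathbf{w}}^{\top}\mathbf{x} + \hat{w}_0 = \sum_{n \in \mathcal{S}} \alpha_n \tilde{y}_n \mathbf{x}_n^{\top}\mathbf{x} + \hat{w}_0 \tag{17.81}
$$

To solve for $\hat{w}_0$ we can use the fact that for any support vector, we have $\tilde{y}_n f(\mathbf{x}_n; \hat{\mathbf{w}}, \hat{w}_0) = 1$. Multiplying
both sides by $\tilde{y}_n$, and exploiting the fact that $\tilde{y}_n^2 = 1$, we get $\hat{w}_0 = \tilde{y}_n - \hat{\mathbf{w}}^{\top}\mathbf{x}_n$. In practice we get
better results by averaging over all the support vectors to get

$$
\hat{w}_0 = \frac{1}{|\mathcal{S}|}\sum_{n \in \mathcal{S}} (\tilde{y}_n - \hat{\mathbf{w}}^{\top}\mathbf{x}_n) = \frac{1}{|\mathcal{S}|}\sum_{n \in \mathcal{S}} \Big(\tilde{y}_n - \sum_{m \in \mathcal{S}} \alpha_m \tilde{y}_m \mathbf{x}_m^{\top}\mathbf{x}_n\Big) \tag{17.82}
$$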
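
To make the link between the dual variables and the primal solution concrete, here is a small sketch on a hypothetical two-point problem, for which the dual can be maximized by hand. The data and the helper `dot` are illustrative assumptions. The code recovers $\hat{\mathbf{w}}$ via Equation (17.73), the bias via Equation (17.82), and checks the complementarity condition of Equation (17.80).

```python
def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

# Hypothetical two-point problem with labels in {-1, +1}.
X = [(0.0, 0.0), (2.0, 0.0)]
y = [-1.0, +1.0]

# The constraint sum_n alpha_n y_n = 0 forces alpha_1 = alpha_2 = a, so the dual
# objective becomes 2a - 0.5 * a^2 * ||x_1 - x_2||^2, maximized at a = 2 / ||x_1 - x_2||^2.
diff = [p - q for p, q in zip(X[0], X[1])]
a = 2.0 / dot(diff, diff)
alpha = [a, a]

# Eq. (17.73): w = sum_n alpha_n y_n x_n
D = len(X[0])
w = [sum(alpha[n] * y[n] * X[n][d] for n in range(len(X))) for d in range(D)]

# Eq. (17.82): average y_n - w^T x_n over the support vectors (alpha_n > 0)
S = [n for n in range(len(X)) if alpha[n] > 1e-12]
w0 = sum(y[n] - dot(w, X[n]) for n in S) / len(S)

# Eq. (17.80): alpha_n (y_n f(x_n) - 1) should vanish for every n
kkt = [alpha[n] * (y[n] * (dot(w, X[n]) + w0) - 1.0) for n in range(len(X))]
```

Both points are support vectors here, and the resulting boundary $x_1 = 1$ lies halfway between them, as the max margin criterion requires.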


17.3.3 Soft margin classifiers 


If the data is not linearly separable, there will be no feasible solution in which $\tilde{y}_n f_n \geq 1$ for all $n$.
We therefore introduce slack variables $\xi_n \geq 0$ and replace the hard constraints that $\tilde{y}_n f_n \geq 1$ with
the soft margin constraints that $\tilde{y}_n f_n \geq 1 - \xi_n$. The new objective becomes

$$
\min_{\mathbf{w}, w_0, \boldsymbol{\xi}} \frac{1}{2}\|\mathbf{w}\|^2 + C\sum_{n=1}^{N} \xi_n \quad \text{s.t.} \quad \xi_n \geq 0, \; \tilde{y}_n(\mathbf{x}_n^{\top}\mathbf{w} + w_0) \geq 1 - \xi_n \tag{17.83}
$$

where $C > 0$ is a hyperparameter controlling how many points we allow to violate the margin
constraint. (If $C = \infty$, we recover the unregularized, hard-margin classifier.)
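The corresponding Lagrangian for the soft margin classifier becomes

$$
L(\mathbf{w}, w_0, \boldsymbol{\alpha}, \boldsymbol{\xi}, \boldsymbol{\mu}) = \frac{1}{2}\mathbf{w}^{\top}\mathbf{w} + C\sum_{n=1}^{N} \xi_n - \sum_{n=1}^{N} \alpha_n\left(\tilde{y}_n(\mathbf{w}^{\top}\mathbf{x}_n + w_0) - 1 + \xi_n\right) - \sum_{n=1}^{N} \mu_n \xi_n \tag{17.84}
$$

where $\alpha_n \geq 0$ and $\mu_n \geq 0$ are the Lagrange multipliers. Optimizing out $\mathbf{w}$, $w_0$ and $\boldsymbol{\xi}$ gives the dual
form

$$
L(\boldsymbol{\alpha}) = \sum_{i=1}^{N} \alpha_i - \frac{1}{2}\sum_{i=1}^{N}\sum_{j=1}^{N} \alpha_i \alpha_j \tilde{y}_i \tilde{y}_j \mathbf{x}_i^{\top}\mathbf{x}_j \tag{17.85}
$$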


This is identical to the hard margin case; however, the constraints are different. In particular, the
KKT conditions imply

\begin{align}
0 \leq \alpha_n &\leq C \tag{17.86} \\
\sum_{n=1}^{N} \alpha_n \tilde{y}_n &= 0 \tag{17.87}
\end{align}

If $\alpha_n = 0$, the point is ignored. If $0 < \alpha_n < C$ then $\xi_n = 0$, so the point lies on the margin. If $\alpha_n = C$,
the point can lie inside the margin, and can either be correctly classified if $\xi_n \leq 1$, or misclassified if
$\xi_n > 1$. See Figure 17.13(b) for an illustration. Hence $\sum_n \xi_n$ is an upper bound on the number of
misclassified points.
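
The role of the slack variables can be illustrated with a short sketch. It assumes that, at the optimum, each slack takes its smallest feasible value, $\xi_n = \max(0, 1 - \tilde{y}_n f(\mathbf{x}_n))$; the labels and decision values below are hypothetical, and the function names are illustrative.

```python
def slack(y, f):
    """Smallest feasible slack: xi_n = max(0, 1 - y_n f(x_n))."""
    return max(0.0, 1.0 - y * f)

def describe(xi):
    """Interpretation of a slack value, following the text."""
    if xi == 0.0:
        return "on or outside the margin"
    if xi <= 1.0:
        return "inside the margin, correctly classified"
    return "misclassified"

# Hypothetical labels and decision values f(x_n).
ys = [+1, +1, +1, -1, -1]
fs = [2.3, 0.4, -0.5, -1.0, 1.7]

xis = [slack(y, f) for y, f in zip(ys, fs)]
n_errors = sum(1 for y, f in zip(ys, fs) if y * f < 0)
```

Every misclassified point contributes a slack larger than 1, so `sum(xis)` can never be smaller than `n_errors`.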
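
As before, the bias term can be computed using

$$
\hat{w}_0 = \frac{1}{|\mathcal{M}|}\sum_{n \in \mathcal{M}} \Big(\tilde{y}_n - \sum_{m \in \mathcal{S}} \alpha_m \tilde{y}_m \mathbf{x}_m^{\top}\mathbf{x}_n\Big) \tag{17.88}
$$

where $\mathcal{M}$ is the set of points having $0 < \alpha_n < C$.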
There is an alternative formulation of the soft margin SVM known as the $\nu$-SVM classifier
[Sch+00]. This involves maximizing

$$
-\frac{1}{2}\sum_{i=1}^{N}\sum_{j=1}^{N} \alpha_i \alpha_j \tilde{y}_i \tilde{y}_j \mathbf{x}_i^{\top}\mathbf{x}_j \tag{17.89}
$$

subject to the constraints that

\begin{align}
0 \leq \alpha_n &\leq 1/N \tag{17.90} \\
\sum_{n=1}^{N} \alpha_n \tilde{y}_n &= 0 \tag{17.91} \\
\sum_{n=1}^{N} \alpha_n &\geq \nu \tag{17.92}
\end{align}

This has the advantage that the parameter $\nu$, which replaces $C$, can be interpreted as an upper
bound on the fraction of margin errors (points for which $\xi_n > 0$), as well as a lower bound on the
fraction of support vectors.


17.3.4 The kernel trick 


So far we have converted the large margin binary classification problem into a dual problem in
$N$ unknowns ($\boldsymbol{\alpha}$) which (in general) takes $O(N^3)$ time to solve, which can be slow. However, the
principal benefit of the dual problem is that we can replace all inner product operations $\mathbf{x}^{\top}\mathbf{x}'$ with a
call to a positive definite (Mercer) kernel function, $\mathcal{K}(\mathbf{x}, \mathbf{x}')$. This is called the kernel trick.


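In particular, we can rewrite the prediction function in Equation (17.81) as follows:

$$
f(\mathbf{x}) = \hat{\mathbf{w}}^{\top}\mathbf{x} + \hat{w}_0 = \sum_{n \in \mathcal{S}} \alpha_n \tilde{y}_n \mathbf{x}_n^{\top}\mathbf{x} + \hat{w}_0 = \sum_{n \in \mathcal{S}} \alpha_n \tilde{y}_n \mathcal{K}(\mathbf{x}_n, \mathbf{x}) + \hat{w}_0 \tag{17.93}
$$

We also need to kernelize the bias term. This can be done by kernelizing Equation (17.82) as follows:

$$
\hat{w}_0 = \frac{1}{|\mathcal{S}|}\sum_{i \in \mathcal{S}} \Big(\tilde{y}_i - \big(\sum_{j \in \mathcal{S}} \hat{\alpha}_j \tilde{y}_j \mathbf{x}_j\big)^{\top}\mathbf{x}_i\Big) = \frac{1}{|\mathcal{S}|}\sum_{i \in \mathcal{S}} \Big(\tilde{y}_i - \sum_{j \in \mathcal{S}} \hat{\alpha}_j \tilde{y}_j \mathcal{K}(\mathbf{x}_j, \mathbf{x}_i)\Big) \tag{17.94}
$$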


The kernel trick allows us to avoid having to deal with an explicit feature representation of our 
data, and allows us to easily apply classifiers to structured objects, such as strings and graphs. 
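
As a small illustration of the kernel trick, the sketch below uses the second order polynomial kernel $\mathcal{K}(x, x') = (1 + xx')^2$ on scalar inputs, whose explicit feature map is $\boldsymbol{\phi}(x) = [1, \sqrt{2}x, x^2]$. The support vectors, dual variables and bias are hypothetical values chosen only to exercise Equation (17.93); they are not the solution of any fitted model.

```python
import math

def phi(x):
    """Explicit feature map [1, sqrt(2) x, x^2] for a scalar input."""
    return [1.0, math.sqrt(2.0) * x, x * x]

def poly_kernel(x, xp):
    """Second order polynomial kernel K(x, x') = (1 + x x')^2."""
    return (1.0 + x * xp) ** 2

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def predict_kernel(x, sv_x, sv_y, sv_alpha, w0, kernel):
    """Eq. (17.93): f(x) = sum_n alpha_n y_n K(x_n, x) + w0."""
    return sum(a * y * kernel(xn, x) for xn, y, a in zip(sv_x, sv_y, sv_alpha)) + w0

# Hypothetical support vectors, labels, dual variables and bias.
sv_x = [0.0, math.sqrt(2.0)]
sv_y = [-1.0, +1.0]
sv_alpha = [0.3, 0.3]
w0 = -0.5

# Primal weights in feature space: w = sum_n alpha_n y_n phi(x_n)
w = [sum(a * y * phi(xn)[d] for xn, y, a in zip(sv_x, sv_y, sv_alpha)) for d in range(3)]

x_test = 0.7
f_dual = predict_kernel(x_test, sv_x, sv_y, sv_alpha, w0, poly_kernel)
f_primal = dot(w, phi(x_test)) + w0
```

The two predictions agree, but the kernelized version never forms $\boldsymbol{\phi}(x)$ explicitly, which is what makes the approach usable when the feature space is very high dimensional.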

Figure 17.15: Log-odds vs x for 3 different methods. Adapted from Figure 10 of [Tip01]. Used with kind
permission of Mike Tipping.


17.3.5 Converting SVM outputs into probabilities 


An SVM classifier produces a hard-labeling, $\hat{y}(\mathbf{x}) = \text{sign}(f(\mathbf{x}))$. However, we often want a measure
of confidence in our prediction. One heuristic approach is to interpret $f(\mathbf{x})$ as the log-odds ratio,
$\log \frac{p(y=1|\mathbf{x})}{p(y=0|\mathbf{x})}$. We can then convert the output of an SVM to a probability using

$$
p(y = 1|\mathbf{x}, \boldsymbol{\theta}) = \sigma(a f(\mathbf{x}) + b) \tag{17.95}
$$


where a, b can be estimated by maximum likelihood on a separate validation set. (Using the training 
set to estimate a and b leads to severe overfitting.) This technique was first proposed in [Pla00], and 
is known as Platt scaling. 
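
A minimal sketch of this calibration step is given below. It assumes labels in {0, 1}, a hypothetical set of validation-set decision values, and plain gradient descent on the mean negative log likelihood; the optimizer and the names `fit_platt`, `lr` and `n_iters` are illustrative choices, not necessarily the procedure used in [Pla00].

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def fit_platt(scores, labels, lr=0.1, n_iters=5000):
    """Fit p(y=1|x) = sigmoid(a * f(x) + b) by gradient descent on the mean NLL."""
    a, b = 1.0, 0.0
    n = len(scores)
    for _ in range(n_iters):
        grad_a = grad_b = 0.0
        for f, y in zip(scores, labels):
            err = sigmoid(a * f + b) - y   # derivative of the NLL wrt the logit
            grad_a += err * f
            grad_b += err
        a -= lr * grad_a / n
        b -= lr * grad_b / n
    return a, b

# Hypothetical decision values f(x_n) on a held-out validation set. The classes
# overlap, so the maximum likelihood estimate of (a, b) is finite.
scores = [-2.1, -1.3, -0.4, -0.2, 0.3, 0.5, 1.1, 1.9]
labels = [0, 0, 1, 0, 0, 1, 1, 1]

a, b = fit_platt(scores, labels)
probs = [sigmoid(a * f + b) for f in scores]
```

Because the map from score to probability is monotone, Platt scaling changes the confidence attached to each prediction but not the ranking of the examples.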

However, the resulting probabilities are not particularly well calibrated, since there is nothing in 
the SVM training procedure that justifies interpreting f(x) as a log-odds ratio. To illustrate this, 
consider an example from [Tip01]. Suppose we have 1d data where $p(x|y = 0) = \text{Unif}(0, 1)$ and
$p(x|y = 1) = \text{Unif}(0.5, 1.5)$. Since the class-conditional distributions overlap in the [0.5, 1] range, the
log-odds of class 1 over class 0 should be zero in this region, and infinite outside this region. We
sampled 1000 points from the model, and then fit a probabilistic kernel classifier (an RVM, described 
in Section 17.4.1) and an SVM with a Gaussian kernel of width 0.1. Both models can perfectly 
capture the decision boundary, and achieve a generalization error of 25%, which is Bayes optimal in 
this problem. The probabilistic output from the RVM is a good approximation to the true log-odds, 
but this is not the case for the SVM, as shown in Figure 17.15. 


17.3.6 Connection with logistic regression 


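We have seen that data points that are on the correct side of the margin have $\xi_n = 0$; for
the others, we have $\xi_n = 1 - \tilde{y}_n f(\mathbf{x}_n)$. Therefore we can rewrite the objective in Equation (17.83) as
follows:

$$
L(\mathbf{w}) = \sum_{n=1}^{N} \ell_{\text{hinge}}(\tilde{y}_n, f(\mathbf{x}_n)) + \lambda\|\mathbf{w}\|^2 \tag{17.96}
$$

where $\lambda = (2C)^{-1}$ and $\ell_{\text{hinge}}(\tilde{y}, \eta)$ is the hinge loss function defined by

$$
\ell_{\text{hinge}}(\tilde{y}, \eta) = \max(0, 1 - \tilde{y}\eta) \tag{17.97}
$$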

Figure 17.16: (a) The one-versus-rest approach. The green region is predicted to be both class 1 and class 2.
(b) The one-versus-one approach. The label of the green region is ambiguous. Adapted from Figure 4.2 of
[Bis06].


As we see from Figure 4.2, this is a convex, piecewise differentiable upper bound to the 0-1 loss, which
has the shape of a partially open door hinge.
By contrast, (penalized) logistic regression optimizes

$$
L(\mathbf{w}) = \sum_{n=1}^{N} \ell_{ll}(\tilde{y}_n, f(\mathbf{x}_n)) + \lambda\|\mathbf{w}\|^2 \tag{17.98}
$$

where the log loss is given by

$$
\ell_{ll}(\tilde{y}, \eta) = -\log p(y|\eta) = \log(1 + e^{-\tilde{y}\eta}) \tag{17.99}
$$


This is also plotted in Figure 4.2. We see that it is similar to the hinge loss, but with two important
differences. First, the hinge loss is piecewise linear, so we cannot use regular gradient methods to
optimize it. (We can, however, compute the subgradient at $\tilde{y}\eta = 1$.) Second, the hinge loss has a
region where it is strictly 0; this results in sparse estimates.

We see that both functions are convex upper bounds on the 0-1 loss, which is given by

$$
\ell_{01}(\tilde{y}, \hat{y}) = \mathbb{I}(\tilde{y} \neq \hat{y}) = \mathbb{I}(\tilde{y}\hat{y} < 0) \tag{17.100}
$$


These upper bounds are easier to optimize and can be viewed as surrogates for the 0-1 loss. See 
Section 4.3.2 for details. 
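
The relationship between the three losses can be verified numerically. The sketch below evaluates them on a grid of margins $\tilde{y}\eta$ for a positive example. One assumption worth flagging: the log loss is measured in bits (base 2), since with the natural logarithm it equals $\log 2 < 1$ at $\eta = 0$ and would have to be rescaled by $1/\log 2$ to upper-bound the 0-1 loss.

```python
import math

def zero_one_loss(y, eta):
    """Eq. (17.100) with yhat = sign(eta): an error when y * eta < 0."""
    return 1.0 if y * eta < 0 else 0.0

def hinge_loss(y, eta):
    """Eq. (17.97)."""
    return max(0.0, 1.0 - y * eta)

def log_loss(y, eta, base=2.0):
    """Eq. (17.99), log(1 + exp(-y * eta)), in units of log `base`."""
    z = -y * eta
    # log(1 + e^z) computed stably for either sign of z
    val = z + math.log1p(math.exp(-z)) if z > 0 else math.log1p(math.exp(z))
    return val / math.log(base)

margins = [i / 10.0 for i in range(-30, 31)]
table = [(m, zero_one_loss(1, m), hinge_loss(1, m), log_loss(1, m)) for m in margins]
```

On this grid both surrogates dominate the 0-1 loss, and only the hinge loss is exactly zero once the margin exceeds 1, which is the source of the sparsity mentioned above.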


17.3.7 Multi-class classification with SVMs


SVMs are inherently binary classifiers. One way to convert them to a multi-class classification
model is to train $C$ binary classifiers, where the data from class $c$ is treated as positive, and the
data from all the other classes is treated as negative. We then use the rule $\hat{y}(\mathbf{x}) = \arg\max_c f_c(\mathbf{x})$ to
predict the final label, where $f_c(\mathbf{x}) = \log \frac{p(c=1|\mathbf{x})}{p(c=0|\mathbf{x})}$ is the score given by classifier $c$. This is known as
the one-versus-the-rest approach (also called one-vs-all).

Figure 17.17: SVM classifier with RBF kernel with precision $\gamma$ and regularizer $C$ applied to two moons data.
Recovered panel titles: $\gamma = 0.1, C = 0.001$ and $\gamma = 0.1, C = 1000$.
Adapted from Figure 5.9 of [Gér19]. Generated by svm_classifier_2d.ipynb.

Unfortunately, this approach has several problems. First, it can result in regions of input space
which are ambiguously labeled. For example, the green region at the top of Figure 17.16(a) is
predicted to be both class 2 and class 1. A second problem is that the magnitudes of the scores $f_c$
are not calibrated with each other, so it is hard to compare them. Finally, each binary subproblem is
likely to suffer from the class imbalance problem (Section 10.3.8.2). For example, suppose we have 10
equally represented classes. When training $f_1$, we will have 10% positive examples and 90% negative
examples, which can hurt performance.

Another approach is to use the one-versus-one or OVO approach, also called all pairs, in which
we train $C(C - 1)/2$ classifiers to discriminate all pairs $f_{c,c'}$. We then classify a point into the class
which has the highest number of votes. However, this can also result in ambiguities, as shown in
Figure 17.16(b). Also, this requires fitting $O(C^2)$ models.
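
The voting step of the one-versus-one scheme can be sketched as follows. The pairwise "classifiers" here are hypothetical stand-ins (each picks the nearer of two class centers in 1d) so that the example is runnable; in practice each would be a trained binary SVM. The function returns every class that attains the top vote count, so a result with more than one entry signals an ambiguous region.

```python
from collections import Counter
from itertools import combinations

def ovo_predict(x, pairwise):
    """Return the classes with the most votes; several entries mean a tie."""
    votes = Counter(clf(x) for clf in pairwise.values())
    best = max(votes.values())
    return sorted(c for c, v in votes.items() if v == best)

# Hypothetical 1d problem: class c has center mu[c].
mu = {0: -2.0, 1: 0.0, 2: 2.0, 3: 5.0}

def make_clf(c, cp):
    """Stand-in pairwise classifier: vote for the nearer of the two centers."""
    return lambda x: c if abs(x - mu[c]) <= abs(x - mu[cp]) else cp

pairwise = {(c, cp): make_clf(c, cp) for c, cp in combinations(sorted(mu), 2)}
n_models = len(pairwise)   # C (C - 1) / 2 classifiers for C classes
```

With four classes this builds six pairwise models, which illustrates the quadratic growth in the number of classifiers.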


17.3.8 How to choose the regularizer C 


SVMs require that you specify the kernel function and the parameter $C$. Typically $C$ is chosen by
cross-validation. Note, however, that $C$ interacts quite strongly with the kernel parameters. For
example, suppose we are using an RBF kernel with precision $\gamma = \frac{1}{2\sigma^2}$. If $\gamma$ is large, corresponding to
narrow kernels, we may need heavy regularization, and hence small $C$. If $\gamma$ is small, a larger value of
$C$ should be used. So we see that $\gamma$ and $C$ are tightly coupled, as illustrated in Figure 17.17.

Figure 17.18: (a) A cross validation estimate of the 0-1 error for an SVM classifier with RBF kernel with
different precisions $\gamma = 1/(2\sigma^2)$ and different regularizer $\lambda = 1/C$, applied to a synthetic data set drawn from
a mixture of 2 Gaussians. (b) A slice through this surface for $\gamma = 5$. The red dotted line is the Bayes optimal
error, computed using Bayes rule applied to the model used to generate the data. Adapted from Figure 12.6 of
[HTF09]. Generated by svmCgammaDemo.ipynb.

The authors of libsvm [HCL03] recommend using CV over a 2d grid with values $C \in \{2^{-5}, 2^{-3}, \ldots, 2^{15}\}$
and $\gamma \in \{2^{-15}, 2^{-13}, \ldots, 2^{3}\}$. See Figure 17.18 which shows the CV estimate of the 0-1 risk as a
function of $C$ and $\gamma$.
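
The recommended grid is easy to enumerate. In the sketch below, `grid_search` and `fake_cv_error` are hypothetical names: the latter is a stand-in for a real cross-validation estimate (which would require fitting an SVM for each pair) and exists only so that the example runs.

```python
import math
from itertools import product

C_grid = [2.0 ** e for e in range(-5, 16, 2)]        # 2^-5, 2^-3, ..., 2^15
gamma_grid = [2.0 ** e for e in range(-15, 4, 2)]    # 2^-15, 2^-13, ..., 2^3

def grid_search(cv_error, C_values, gamma_values):
    """Return the (C, gamma) pair with the lowest cross-validated error."""
    return min(product(C_values, gamma_values), key=lambda p: cv_error(*p))

def fake_cv_error(C, gamma):
    """Stand-in for a CV estimate: a bowl in log2 space with minimum at (2^3, 2^-5)."""
    return (math.log2(C) - 3) ** 2 + (math.log2(gamma) + 5) ** 2

best_C, best_gamma = grid_search(fake_cv_error, C_grid, gamma_grid)
```

The grid is searched jointly rather than one parameter at a time because, as noted above, the best value of $C$ depends on $\gamma$.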

To choose $C$ efficiently, one can develop a path following algorithm in the spirit of LARS (Section 11.4.4).
The basic idea is to start with $C$ small, so that the margin is wide, and hence all points
are inside of it and have $\alpha_i = 1$. By slowly increasing $C$, a small set of points will move from inside
the margin to outside, and their $\alpha_i$ values will change from 1 to 0, as they cease to be support vectors.
When $C$ is maximal, the margin becomes empty, and no support vectors remain. See [Has+04] for
the details.


17.3.9 Kernel ridge regression 


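Recall the equation for ridge regression from Equation (11.55):

$$
\hat{\mathbf{w}}_{\text{map}} = (\mathbf{X}^{\top}\mathbf{X} + \lambda\mathbf{I}_D)^{-1}\mathbf{X}^{\top}\mathbf{y} = \Big(\sum_n \mathbf{x}_n\mathbf{x}_n^{\top} + \lambda\mathbf{I}_D\Big)^{-1}\Big(\sum_n y_n\mathbf{x}_n\Big) \tag{17.101}
$$

Using the matrix inversion lemma (Section 7.3.3), we can rewrite the ridge estimate as follows

$$
\mathbf{w} = \mathbf{X}^{\top}(\mathbf{X}\mathbf{X}^{\top} + \lambda\mathbf{I}_N)^{-1}\mathbf{y} = \sum_n \mathbf{x}_n \Big((\mathbf{X}\mathbf{X}^{\top} + \lambda\mathbf{I}_N)^{-1}\mathbf{y}\Big)_n \tag{17.102}
$$

where $\mathbf{X}\mathbf{X}^{\top}$ is the $N \times N$ matrix of inner products, with entries $\mathbf{x}_n^{\top}\mathbf{x}_m$. Let us define the following dual variables:

$$
\boldsymbol{\alpha} \triangleq (\mathbf{X}\mathbf{X}^{\top} + \lambda\mathbf{I}_N)^{-1}\mathbf{y} \tag{17.103}
$$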
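
Then we can rewrite the primal variables as follows

$$
\mathbf{w} = \mathbf{X}^{\top}\boldsymbol{\alpha} = \sum_{n=1}^{N} \alpha_n \mathbf{x}_n \tag{17.104}
$$

This tells us that the solution vector is just a linear sum of the $N$ training vectors. When we plug
this in at test time to compute the predictive mean, we get

$$
f(\mathbf{x}; \mathbf{w}) = \mathbf{w}^{\top}\mathbf{x} = \sum_{n=1}^{N} \alpha_n \mathbf{x}_n^{\top}\mathbf{x} \tag{17.105}
$$

We can then use the kernel trick to rewrite this as

$$
f(\mathbf{x}; \mathbf{w}) = \sum_{n=1}^{N} \alpha_n \mathcal{K}(\mathbf{x}_n, \mathbf{x}) \tag{17.106}
$$

where

$$
\boldsymbol{\alpha} = (\mathbf{K} + \lambda\mathbf{I}_N)^{-1}\mathbf{y} \tag{17.107}
$$

In other words,

$$
f(\mathbf{x}; \mathbf{w}) = \mathbf{k}^{\top}(\mathbf{K} + \lambda\mathbf{I}_N)^{-1}\mathbf{y} \tag{17.108}
$$

where $\mathbf{k} = [\mathcal{K}(\mathbf{x}, \mathbf{x}_1), \ldots, \mathcal{K}(\mathbf{x}, \mathbf{x}_N)]$. This is called kernel ridge regression.
The trouble with the above approach is that the solution vector $\boldsymbol{\alpha}$ is not sparse, so predictions at
test time will take $O(N)$ time. We discuss a solution to this in Section 17.3.10.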
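
Kernel ridge regression is short enough to write out in full. The sketch below is dependency-free and assumes an RBF kernel on scalar inputs, a hand-written Gaussian elimination routine in place of a linear algebra library, and a hypothetical five-point dataset; the function names are illustrative.

```python
import math

def rbf(x, xp, gamma=1.0):
    """RBF kernel on scalars with precision gamma."""
    return math.exp(-gamma * (x - xp) ** 2)

def solve(A, b):
    """Solve A z = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    z = [0.0] * n
    for r in reversed(range(n)):
        z[r] = (M[r][n] - sum(M[r][c] * z[c] for c in range(r + 1, n))) / M[r][r]
    return z

def krr_fit(xs, ys, lam, kernel):
    """Eq. (17.107): alpha = (K + lam I)^{-1} y."""
    N = len(xs)
    A = [[kernel(xs[i], xs[j]) + (lam if i == j else 0.0) for j in range(N)]
         for i in range(N)]
    return solve(A, ys)

def krr_predict(x, xs, alpha, kernel):
    """Eq. (17.106): f(x) = sum_n alpha_n K(x_n, x)."""
    return sum(a * kernel(xn, x) for a, xn in zip(alpha, xs))

# Hypothetical training data.
xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
ys = [math.sin(x) for x in xs]

alpha = krr_fit(xs, ys, lam=1e-6, kernel=rbf)
fitted = [krr_predict(x, xs, alpha, rbf) for x in xs]
```

With a very small $\lambda$ the fit nearly interpolates the training targets. Note that `krr_predict` loops over every training point, which is the $O(N)$ test-time cost discussed above.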


17.3.10 SVMs for regression 
Consider the following $\ell_2$-regularized ERM problem:

$$
J(\mathbf{w}, \lambda) = \lambda\|\mathbf{w}\|^2 + \sum_{n=1}^{N} \ell(\tilde{y}_n, \hat{y}_n) \tag{17.109}
$$

where $\hat{y}_n = \mathbf{w}^{\top}\mathbf{x}_n + w_0$. If we use the quadratic loss, $\ell(y, \hat{y}) = (y - \hat{y})^2$, where $\tilde{y}_n \in \mathbb{R}$, we recover
ridge regression (Section 11.3). If we then apply the kernel trick, we recover kernel ridge regression
(Section 17.3.9).

The problem with kernel ridge regression is that the solution depends on all $N$ training points,
which makes it computationally expensive when $N$ is large. However, by changing the loss function, we can make
the optimal set of basis function coefficients, $\boldsymbol{\alpha}^*$, be sparse, as we show below.

In particular, consider the following variant of the Huber loss function (Section 5.1.5.3) called the
epsilon insensitive loss function:

$$
L_{\epsilon}(y, \hat{y}) \triangleq \begin{cases} 0 & \text{if } |y - \hat{y}| < \epsilon \\ |y - \hat{y}| - \epsilon & \text{otherwise} \end{cases} \tag{17.110}
$$

Figure 17.19: (a) Illustration of $\ell_2$, Huber and $\epsilon$-insensitive loss functions, where $\epsilon = 1.5$. Generated by
huberLossPlot.ipynb. (b) Illustration of the $\epsilon$-tube used in SVM regression. Points above the tube have $\xi_i^+ > 0$
and $\xi_i^- = 0$. Points below the tube have $\xi_i^+ = 0$ and $\xi_i^- > 0$. Points inside the tube have $\xi_i^+ = \xi_i^- = 0$.
Adapted from Figure 7.7 of [Bis06].

This means that any point lying inside an $\epsilon$-tube around the prediction is not penalized, as in
Figure 17.19.
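
The three losses compared in Figure 17.19(a) can be written as functions of the residual $r = y - \hat{y}$. This is a minimal sketch: the Huber loss is written in its common form ($r^2/2$ inside the threshold $\delta$, linear outside), which is assumed to match Section 5.1.5.3, and the default values of `delta` and `eps` follow the $\epsilon = 1.5$ used in the figure.

```python
def l2_loss(r):
    """Quadratic loss on the residual r = y - yhat."""
    return r * r

def huber_loss(r, delta=1.5):
    """Quadratic for |r| <= delta, linear beyond it."""
    a = abs(r)
    return 0.5 * r * r if a <= delta else delta * a - 0.5 * delta * delta

def eps_insensitive_loss(r, eps=1.5):
    """Eq. (17.110): zero inside the epsilon-tube, linear outside it."""
    return max(0.0, abs(r) - eps)

residuals = [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]
rows = [(r, l2_loss(r), huber_loss(r), eps_insensitive_loss(r)) for r in residuals]
```

Only the $\epsilon$-insensitive loss is exactly zero over an interval of residuals, and it is this flat region that allows many training points to drop out of the solution.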
The corresponding objective function is usually written in the following form

$$
J = \frac{1}{2}\|\mathbf{w}\|^2 + C\sum_{n=1}^{N} L_{\epsilon}(\tilde{y}_n, \hat{y}_n) \tag{17.111}
$$

where $\hat{y}_n = f(\mathbf{x}_n) = \mathbf{w}^{\top}\mathbf{x}_n + w_0$ and $C = 1/\lambda$ is a regularization constant. This objective is convex
and unconstrained, but not differentiable, because of the absolute value function in the loss term.
As in Section 11.4.9, where we discussed the lasso problem, there are several possible algorithms we
could use. One popular approach is to formulate the problem as a constrained optimization problem.
In particular, we introduce slack variables to represent the degree to which each point lies outside
the tube:

\begin{align}
\tilde{y}_n &\leq f(\mathbf{x}_n) + \epsilon + \xi_n^+ \tag{17.112} \\
\tilde{y}_n &\geq f(\mathbf{x}_n) - \epsilon - \xi_n^- \tag{17.113}
\end{align}

Given this, we can rewrite the objective as follows:

$$
J = \frac{1}{2}\|\mathbf{w}\|^2 + C\sum_{n=1}^{N} (\xi_n^+ + \xi_n^-) \tag{17.114}
$$

This is a quadratic function of $\mathbf{w}$, and must be minimized subject to the linear constraints in
Equations 17.112-17.113, as well as the positivity constraints $\xi_n^+ \geq 0$ and $\xi_n^- \geq 0$. This is a standard
quadratic program in $2N + D + 1$ variables.

By forming the Lagrangian and optimizing, as we did above, one can show that the optimal solution
has the following form

$$
\hat{\mathbf{w}} = \sum_{n} \alpha_n \mathbf{x}_n \tag{17.115}
$$

where $\alpha_n \geq 0$ are the dual variables. (See e.g., [SS02] for details.) Fortunately, the $\boldsymbol{\alpha}$ vector is sparse,
meaning that many of its entries are equal to 0. This is because the loss doesn't care about errors
which are smaller than $\epsilon$. The degree of sparsity is controlled by $C$ and $\epsilon$.

The $\mathbf{x}_n$ for which $\alpha_n > 0$ are called the support vectors; these are points for which the errors lie
on or outside the $\epsilon$ tube. These are the only training examples we need to keep for prediction at test
time, since

$$
f(\mathbf{x}) = \hat{w}_0 + \hat{\mathbf{w}}^{\top}\mathbf{x} = \hat{w}_0 + \sum_{n: \alpha_n > 0} \alpha_n \mathbf{x}_n^{\top}\mathbf{x} \tag{17.116}
$$

Finally, we can use the kernel trick to get

$$
f(\mathbf{x}) = \hat{w}_0 + \sum_{n: \alpha_n > 0} \alpha_n \mathcal{K}(\mathbf{x}_n, \mathbf{x}) \tag{17.117}
$$


This overall technique is called support vector machine regression or SVM regression for 
short, and was first proposed in [VGS97]. 

In Figure 17.20, we give an example where we use an RBF kernel with $\gamma = 1$. When $C$ is small,
the model is heavily regularized; when $C$ is large, the model is less regularized and can fit the data
better. We also see that when $\epsilon$ is small, the tube is smaller, so there are more support vectors.


17.4 Sparse vector machines 


GPs are very flexible models, but incur an O(N) time cost at prediction time, which can be prohibitive. 
SVMs solve that problem by estimating a sparse weight vector. However, SVMs do not give calibrated 
probabilistic outputs. 

We can get the best of both worlds by using parametric models, where the feature vector is defined
using basis functions centered on each of the training points, as follows:

$$
\boldsymbol{\phi}(\mathbf{x}) = [\mathcal{K}(\mathbf{x}, \mathbf{x}_1), \ldots, \mathcal{K}(\mathbf{x}, \mathbf{x}_N)] \tag{17.118}
$$

where $\mathcal{K}$ is any similarity kernel, not necessarily a Mercer kernel. Equation (17.118) maps $\mathbf{x} \in \mathcal{X}$
into $\boldsymbol{\phi}(\mathbf{x}) \in \mathbb{R}^N$. We can plug this new feature vector into any discriminative model, such as logistic
regression. Since we have $D = N$ parameters, we need to use some kind of regularization, to prevent
overfitting. If we fit such a model using $\ell_2$ regularization (which we will call L2VM, for $\ell_2$-vector
machine), the result often has good predictive performance, but the weight vector $\mathbf{w}$ will be dense,
and will depend on all $N$ training points. A natural solution is to impose a sparsity-promoting
prior on $\mathbf{w}$, so that not all the exemplars need to be kept. We call such methods sparse vector
machines.

Figure 17.20: Illustration of support vector regression. Recovered panel titles: $C = 100, \epsilon = 0.1$ and $C = 0.01, \epsilon = 0.1$.
Adapted from Figure 5.11 of [Gér19]. Generated by svm_regression_1d.ipynb.


17.4.1 Relevance vector machines (RVMs) 


The simplest way to ensure $\mathbf{w}$ is sparse is to use $\ell_1$ regularization, as in Section 11.4. We call this
L1VM or Laplace vector machine, since this approach is equivalent to using MAP estimation
with a Laplace prior for $\mathbf{w}$.

However, sometimes $\ell_1$ regularization does not result in a sufficient level of sparsity for a given
level of accuracy. An alternative approach is based on the use of ARD or automatic relevancy 
determination, which uses type II maximum likelihood (aka empirical Bayes) to estimate a sparse 
weight vector [Mac95; Nea96]. If we apply this technique to a feature vector defined in terms of 
kernels, as in Equation (17.118), we get a method called the relevance vector machine or RVM 
[Tip01; TF03]. 


17.4.2 Comparison of sparse and dense kernel methods 


In Figure 17.21, we compare L2VM, L1VM, RVM and an SVM using an RBF kernel on a binary
classification problem in 2d. We use cross validation to pick $C = 1/\lambda$ for the SVM (see Section 17.3.8),
and then use the same value of the regularizer for L2VM and L1VM. We see that all the methods
give similar predictive performance. However, we see that the RVM is the sparsest model, so it will
be the fastest at run time.

Figure 17.21: Example of non-linear binary classification using an RBF kernel with bandwidth $\sigma = 0.3$. (a)
L2VM. (b) L1VM. (c) RVM. (d) SVM. Green circles denote the support vectors. Recovered panel titles:
"logregL2, nerr=27" and "logregL1, nerr=20, nsupport=88". Generated by kernelBinaryClassifDemo.ipynb.

In Figure 17.22, we compare L2VM, L1VM, RVM and an SVM using an RBF kernel on a 1d
regression problem. Again, we see that predictions are quite similar, but RVM is the sparsest, then
L1VM, then SVM. This is further illustrated in Figure 17.23.

Beyond these small empirical examples, we provide a more general summary of the different 
methods in Table 17.1. The columns of this table have the following meaning: 


• Optimize w: a key question is whether the objective $\mathcal{L}(\mathbf{w}) = -\log p(\mathcal{D}|\mathbf{w}) - \log p(\mathbf{w})$ is convex or
not. L2VM, L1VM and SVMs have convex objectives. RVMs do not. GPs are Bayesian methods
that integrate out the weights $\mathbf{w}$.

Figure 17.22: Model fits for kernel based regression on the noisy sinc function using an RBF kernel with
bandwidth $\sigma = 0.3$. (a) L2VM with $\lambda = 0.5$. (b) L1VM with $\lambda = 0.5$. (c) RVM. (d) SVM regression with
$C = 1/\lambda$ chosen by cross validation. Red circles denote the retained training exemplars. Generated by
rvm_regression_1d.ipynb.


• Optimize kernel: all the methods require that we “tune” the kernel parameters, such as the
bandwidth of the RBF kernel, as well as the level of regularization. For methods based on
Gaussian priors, including L2VM, RVMs and GPs, we can use efficient gradient based optimizers
to maximize the marginal likelihood. For SVMs and L1VMs, we must use cross validation, which
is slower (see Section 17.3.8).


• Sparse: L1VM, RVMs and SVMs are sparse kernel methods, in that they only use a subset of
the training examples. GPs and L2VM are not sparse: they use all the training examples. The
principal advantage of sparsity is that prediction at test time is usually faster. However, sparsity
usually results in overconfident predictions.


• Probabilistic: All the methods except for SVMs produce probabilistic output of the form $p(y|\mathbf{x})$.
SVMs produce a “confidence” value that can be converted to a probability, but such probabilities
are usually very poorly calibrated (see Section 17.3.5).


• Multiclass: All the methods except for SVMs naturally work in the multiclass setting, by using a
categorical distribution instead of a Bernoulli. The SVM can be made into a multiclass classifier,
but there are various difficulties with this approach, as discussed in Section 17.3.7.

Figure 17.23: Estimated coefficients for the models in Figure 17.22. Generated by rvm_regression_1d.ipynb.


Method   Opt. w       Opt. kernel   Sparse   Prob.   Multiclass   Non-Mercer   Section
SVM      Convex       CV            Yes      No      Indirectly   No           17.3
L2VM     Convex       EB            No       Yes     Yes          Yes          17.4.1
L1VM     Convex       CV            Yes      Yes     Yes          Yes          17.4.1
RVM      Not convex   EB            Yes      Yes     Yes          Yes          17.4.1
GP       N/A          EB            No       Yes     Yes          No           17.2.7


Table 17.1: Comparison of various kernel based classifiers. EB = empirical Bayes, CV = cross validation.
See text for details.


• Mercer kernel: SVMs and GPs require that the kernel is positive definite; the other techniques do
not, since the kernel function in Equation (17.118) can be an arbitrary function of two inputs.


17.5 Exercises 


Exercise 17.1 [Fitting an SVM classifier by hand *] 


(Source: Jaakkola.) Consider a dataset with 2 points in 1d: $x_1 = 0$ with label $y_1 = -1$ and $x_2 = \sqrt{2}$ with
label $y_2 = 1$. Consider mapping each point to 3d using the feature vector $\boldsymbol{\phi}(x) = [1, \sqrt{2}x, x^2]^{\top}$. (This is
equivalent to using a second order polynomial kernel.) The max margin classifier has the form

\begin{align}
\min \|\mathbf{w}\|^2 \quad &\text{s.t.} \tag{17.119} \\
y_1(\mathbf{w}^{\top}\boldsymbol{\phi}(x_1) + w_0) &\geq 1 \tag{17.120} \\
y_2(\mathbf{w}^{\top}\boldsymbol{\phi}(x_2) + w_0) &\geq 1 \tag{17.121}
\end{align}


a. Write down a vector that is parallel to the optimal vector w. Hint: recall from Figure 17.12(a) that w is 
perpendicular to the decision boundary between the two points in the 3d feature space. 


b. What is the value of the margin that is achieved by this w? Hint: recall that the margin is the distance 
from each support vector to the decision boundary. Hint 2: think about the geometry of 2 points in space, 
with a line separating one from the other. 


c. Solve for $\mathbf{w}$, using the fact that the margin is equal to $1/\|\mathbf{w}\|$.


d. Solve for $w_0$ using your value for $\mathbf{w}$ and Equations 17.119 to 17.121. Hint: the points will be on the
decision boundary, so the inequalities will be tight.


e. Write down the form of the discriminant function $f(x) = w_0 + \mathbf{w}^{\top}\boldsymbol{\phi}(x)$ as an explicit function of $x$.

18.1 Classification and regression trees (CART)


Classification and regression trees or CART models [BFO84], also called decision trees 
[Qui86; Qui93], are defined by recursively partitioning the input space, and defining a local model in 
each resulting region of input space. The overall model can be represented by a tree, with one leaf 
per region, as we explain below. 


18.1.1 Model definition 


We start by considering regression trees, where all inputs are real-valued. The tree consists of a set
of nested decision rules. At each node $i$, the feature dimension $d_i$ of the input vector $\mathbf{x}$ is compared
to a threshold value $t_i$, and the input is then passed down to the left or right branch, depending on
whether it is above or below threshold. At the leaves of the tree, the model specifies the predicted
output for any input that falls into that part of the input space.

For example, consider the regression tree in Figure 18.1(a). The first node asks if $x_1$ is less than
some threshold $t_1$. If yes, we then ask if $x_2$ is less than some other threshold $t_2$. If yes, we enter the
bottom left leaf node. This corresponds to the region of space defined by

$$
R_1 = \{\mathbf{x} : x_1 \leq t_1, x_2 \leq t_2\} \tag{18.1}
$$

We can associate this region with the predicted output, say $y = 2$. In a similar way, we can partition
the entire input space into 5 regions using axis parallel splits, as shown in Figure 18.1(b).¹
Formally, a regression tree can be defined by

$$
f(\mathbf{x}; \boldsymbol{\theta}) = \sum_{j=1}^{J} w_j \mathbb{I}(\mathbf{x} \in R_j) \tag{18.2}
$$

where $R_j$ is the region specified by the $j$'th leaf node, $w_j$ is the predicted output for that node,

$$
w_j = \frac{\sum_{n=1}^{N} y_n \mathbb{I}(\mathbf{x}_n \in R_j)}{\sum_{n=1}^{N} \mathbb{I}(\mathbf{x}_n \in R_j)} \tag{18.3}
$$

and $\boldsymbol{\theta} = \{(R_j, w_j) : j = 1:J\}$, where $J$ is the number of leaves. The regions themselves are defined
by the feature dimensions that are used in each split, and the corresponding thresholds, on the
path from the root to the leaf. For example, in Figure 18.1(a), we have $R_1 = [(x_1 \leq t_1), (x_2 \leq t_2)]$,
$R_4 = [(x_1 \leq t_1), (x_2 > t_2), (x_3 \leq t_3)]$, etc. (For categorical inputs, we can define the splits based
on comparing feature $x_i$ to each of the possible values for that feature, rather than comparing to a
numeric threshold.) We discuss how to learn these regions in Section 18.1.2.

1. By using enough splits (i.e., deep enough trees), we can make a piecewise linear approximation to decision boundaries
with more complex shapes, but it may require a lot of data to fit such a model.

Figure 18.1: (a) A regression tree on two inputs. (b) Corresponding piecewise constant surface, where the
regions have heights 2, 4, 6, 8 and 10. Adapted from Figure 9.2 of [HTF09]. Generated by regtreeSurfaceDemo.ipynb.

Figure 18.2: (a) A set of shapes with corresponding binary labels. The features are: color (values “blue”, “red”,
“other”), shape (values “ellipse”, “other”), and size (real-valued). (b) A hypothetical classification tree fitted to
this data. A leaf labeled as $(n_1, n_0)$ means that there are $n_1$ positive examples that fall into this partition, and
$n_0$ negative examples.

For classification problems, the leaves contain a distribution over the class labels, rather than just
the mean response. See Figure 18.2 for an example of a classification tree.
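
Evaluating the regression tree of Equation (18.2) amounts to following the nested decision rules from the root to a leaf. The sketch below assumes a simple `Node` record and a tree whose thresholds and leaf values are hypothetical, chosen only for illustration; they are not the values shown in Figure 18.1.

```python
from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass
class Node:
    feature: Optional[int] = None       # index d_i of the feature tested at this node
    threshold: Optional[float] = None   # threshold t_i
    left: Optional["Node"] = None       # branch taken when x[feature] <= threshold
    right: Optional["Node"] = None      # branch taken otherwise
    value: Optional[float] = None       # prediction w_j (set on leaves only)

def predict(node: Node, x: Sequence[float]) -> float:
    """Descend from the root until a leaf is reached, then return its value."""
    while node.value is None:
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.value

# Hypothetical tree on two inputs with four leaves.
tree = Node(0, 5.0,
            left=Node(1, 3.0, left=Node(value=2.0), right=Node(value=4.0)),
            right=Node(0, 7.0, left=Node(value=6.0), right=Node(value=8.0)))

y_hat = predict(tree, (1.0, 1.0))   # x1 <= 5 and x2 <= 3, so the first leaf is reached
```

Each leaf corresponds to one axis-aligned region $R_j$, so the function computed by `predict` is piecewise constant.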

18.1.2 Model fitting


To fit the model, we need to minimize the following loss: 


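$$
\mathcal{L}(\boldsymbol{\theta}) = \sum_{n=1}^{N} \ell(y_n, f(\mathbf{x}_n; \boldsymbol{\theta})) = \sum_{j=1}^{J} \sum_{\mathbf{x}_n \in R_j} \ell(y_n, w_j) \tag{18.4}
$$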


Unfortunately, this is not differentiable, because of the need to learn the discrete tree structure. 
Indeed, finding the optimal partitioning of the data is NP-complete [HR76]. The standard practice is 
to use a greedy procedure, in which we iteratively grow the tree one node at a time. This approach is 
used by CART [BFO84], C4.5 [Qui93], and ID3 [Qui86], which are three popular implementations 
of the method. 

The idea is as follows. Suppose we are at node $i$; let $\mathcal{D}_i = \{(\mathbf{x}_n, y_n) \in N_i\}$ be the set of examples
that reach this node. We will consider how to split this node into a left branch and right branch so
as to minimize the error in each child subtree.

If the $j$'th feature is a real-valued scalar, we can partition the data at node $i$ by comparing to a
threshold $t$. The set of possible thresholds $\mathcal{T}_j$ for feature $j$ can be obtained by sorting the unique values
of $\{x_{nj}\}$. For example, if feature 1 has the values $\{4.5, -12, 72, -12\}$, then we set $\mathcal{T}_1 = \{-12, 4.5, 72\}$.
For each possible threshold, we define the left and right splits, $\mathcal{D}_i^L(j, t) = \{(\mathbf{x}_n, y_n) \in N_i : x_{n,j} \leq t\}$
and $\mathcal{D}_i^R(j, t) = \{(\mathbf{x}_n, y_n) \in N_i : x_{n,j} > t\}$.

If the $j$'th feature is categorical, with $K_j$ possible values, then we check if the feature is equal to
each of those values or not. This defines a set of $K_j$ possible binary splits: $\mathcal{D}_i^L(j, t) = \{(\mathbf{x}_n, y_n) \in N_i :
x_{n,j} = t\}$ and $\mathcal{D}_i^R(j, t) = \{(\mathbf{x}_n, y_n) \in N_i : x_{n,j} \neq t\}$. (Alternatively, we could allow for a multi-way
split, as in Figure 18.2(b). However, this may cause data fragmentation, in which too little data
might “fall” into each subtree, resulting in overfitting. Therefore it is more common to use binary
splits.)

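Once we have computed $\mathcal{D}_i^L(j,t)$ and $\mathcal{D}_i^R(j,t)$ for each j and t at node i, we choose the best feature
$j_i$ to split on, and the best value for that feature, $t_i$, as follows:

$$(j_i, t_i) = \arg\min_{j \in \{1,\ldots,D\}} \min_{t \in \mathcal{T}_j} \frac{|\mathcal{D}_i^L(j,t)|}{|\mathcal{D}_i|} c(\mathcal{D}_i^L(j,t)) + \frac{|\mathcal{D}_i^R(j,t)|}{|\mathcal{D}_i|} c(\mathcal{D}_i^R(j,t)) \tag{18.5}$$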


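We now discuss the cost function $c(\mathcal{D}_i)$ which is used to evaluate the cost of node i. For regression,
we can use the mean squared error

$$c(\mathcal{D}_i) = \frac{1}{|\mathcal{D}_i|} \sum_{n \in \mathcal{D}_i} (y_n - \bar{y})^2 \tag{18.6}$$

where $\bar{y} = \frac{1}{|\mathcal{D}_i|} \sum_{n \in \mathcal{D}_i} y_n$ is the mean of the response variable for examples reaching node i.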
For classification, we first compute the empirical distribution over class labels for this node: 


$$\hat{\pi}_{ic} = \frac{1}{|\mathcal{D}_i|} \sum_{n \in \mathcal{D}_i} \mathbb{I}(y_n = c) \tag{18.7}$$


Given this, we can then compute the Gini index 
$$G_i = \sum_{c=1}^{C} \hat{\pi}_{ic}(1 - \hat{\pi}_{ic}) = \sum_{c} \hat{\pi}_{ic} - \sum_{c} \hat{\pi}_{ic}^2 = 1 - \sum_{c} \hat{\pi}_{ic}^2 \tag{18.8}$$

This is the expected error rate. To see this, note that $\hat{\pi}_{ic}$ is the probability a random entry in the
leaf belongs to class c, and $1 - \hat{\pi}_{ic}$ is the probability it would be misclassified.
Alternatively we can define cost as the entropy or deviance of the node: 


$$H_i = \mathbb{H}(\hat{\boldsymbol{\pi}}_i) = -\sum_{c=1}^{C} \hat{\pi}_{ic} \log \hat{\pi}_{ic} \tag{18.9}$$


A node that is pure (i.e., only has examples of one class) will have 0 entropy. 
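The three node costs above are simple to compute from the labels (or responses) that reach a node. The following is a minimal sketch in plain Python; the function names are our own, and the entropy is measured in nats, which is an assumption since the base of the logarithm only rescales the cost.

```python
import math
from collections import Counter

def class_distribution(labels):
    """Empirical distribution over class labels at a node (Equation 18.7)."""
    n = len(labels)
    return {c: k / n for c, k in Counter(labels).items()}

def gini(labels):
    """Gini index (Equation 18.8): 1 - sum_c pi_c^2."""
    return 1.0 - sum(p * p for p in class_distribution(labels).values())

def entropy(labels):
    """Entropy in nats (Equation 18.9); classes that never occur contribute nothing."""
    return -sum(p * math.log(p) for p in class_distribution(labels).values())

def mse(values):
    """Mean squared error around the node mean (Equation 18.6)."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)

print(gini(["a", "a", "b", "b"]), entropy(["a", "a", "b", "b"]))  # mixed node
print(gini(["a", "a", "a"]), entropy(["a", "a", "a"]))            # pure node
```

A pure node has zero Gini index and zero entropy, while a node with two equally frequent classes attains the maximum of each for two classes.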

Given one of the above cost functions, we can use Equation (18.5) to pick the best feature, and 
best threshold at each node. We then partition the data, and call the fitting algorithm recursively on 
each subset of the data. 
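The greedy procedure can be sketched in a few lines. This is an illustrative implementation for real-valued features with the Gini cost, not the code used by CART or any other package; the dictionary-based tree representation, the `max_depth` stopping rule, and all function names are our own assumptions.

```python
from collections import Counter

def gini(labels):
    n = len(labels)
    return 1.0 - sum((k / n) ** 2 for k in Counter(labels).values())

def best_split(X, y, cost=gini):
    """Return (weighted cost, feature j, threshold t) minimizing Equation (18.5), or None."""
    n, best = len(y), None
    for j in range(len(X[0])):
        thresholds = sorted({row[j] for row in X})
        for t in thresholds[:-1]:  # the largest value would leave the right child empty
            left = [y[i] for i in range(n) if X[i][j] <= t]
            right = [y[i] for i in range(n) if X[i][j] > t]
            score = len(left) / n * cost(left) + len(right) / n * cost(right)
            if best is None or score < best[0]:
                best = (score, j, t)
    return best

def fit_tree(X, y, depth=0, max_depth=3):
    """Grow the tree one node at a time; leaves store the class counts that reach them."""
    split = best_split(X, y) if depth < max_depth and len(set(y)) > 1 else None
    if split is None:
        return {"leaf": Counter(y)}
    _, j, t = split
    left = [i for i in range(len(y)) if X[i][j] <= t]
    right = [i for i in range(len(y)) if X[i][j] > t]
    return {"feature": j, "threshold": t,
            "left": fit_tree([X[i] for i in left], [y[i] for i in left], depth + 1, max_depth),
            "right": fit_tree([X[i] for i in right], [y[i] for i in right], depth + 1, max_depth)}

def predict(tree, x):
    """Follow the splits down to a leaf and return its majority class."""
    while "leaf" not in tree:
        tree = tree["left"] if x[tree["feature"]] <= tree["threshold"] else tree["right"]
    return tree["leaf"].most_common(1)[0][0]

# Feature 0 takes the values used in the thresholding example above.
X = [[4.5, 0.0], [-12.0, 1.0], [72.0, 0.0], [-12.0, 1.0]]
y = ["a", "b", "a", "b"]
tree = fit_tree(X, y)
print(tree["feature"], tree["threshold"], predict(tree, [10.0, 0.0]))
```

Each call to `best_split` costs time proportional to the number of features times the number of candidate thresholds times the number of examples; practical implementations sort each feature once and update the class counts incrementally instead.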


18.1.3 Regularization 


If we let the tree become deep enough, it can achieve 0 error on the training set (assuming no label 
noise), by partitioning the input space into sufficiently small regions where the output is constant.
However, this will typically result in overfitting. To prevent this, there are two main approaches. The 
first is to stop the tree growing process according to some heuristic, such as having too few examples 
at a node, or reaching a maximum depth. The second approach is to grow the tree to its maximum 
depth, where no more splits are possible, and then to prune it back, by merging split subtrees back 
into their parent (see e.g., [BA97b]). This can partially overcome the greedy nature of top-down tree 
growing. (For example, consider applying the top-down approach to the xor data in Figure 13.1: 
the algorithm would never make any splits, since each feature on its own has no predictive power.) 
However, forward growing and backward pruning is slower than the greedy top-down approach. 
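The xor example can be checked directly. The sketch below uses the four noise-free xor points (an idealization of the data in the figure, which we assume is a noisy version of the same pattern) and computes the reduction in Gini index for a split on each feature.

```python
from collections import Counter

def gini(labels):
    n = len(labels)
    return 1.0 - sum((k / n) ** 2 for k in Counter(labels).values())

# The four xor points: the label is 1 exactly when the two binary inputs differ.
X = [(0, 0), (0, 1), (1, 0), (1, 1)]
y = [0, 1, 1, 0]

parent_cost = gini(y)
gains = []
for j in range(2):
    left = [y[i] for i in range(4) if X[i][j] <= 0]
    right = [y[i] for i in range(4) if X[i][j] > 0]
    child_cost = len(left) / 4 * gini(left) + len(right) / 4 * gini(right)
    gains.append(parent_cost - child_cost)

print(gains)  # gain of splitting on feature 0 and on feature 1
```

Both children of either split still contain one example of each class, so the weighted child cost equals the parent cost and a greedy learner that requires a positive gain stops at the root.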


18.1.4 Handling missing input features 


In general, it is hard for discriminative models, such as neural networks, to handle missing input 
features, as we discussed in Section 1.5.5. However, for trees, there are some simple heuristics that 
can work well. 

The standard heuristic for handling missing inputs in decision trees is to look for a series of “backup”
variables, which can induce a similar partition to the chosen variable at any given split; these can be 
used in case the chosen variable is unobserved at test time. These are called surrogate splits. This 
method finds highly correlated features, and can be thought of as learning a local joint model of the 
input. This has the advantage over a generative model of not modeling the entire joint distribution 
of inputs, but it has the disadvantage of being entirely ad hoc. A simpler approach, applicable to 
categorical variables, is to code “missing” as a new value, and then to treat the data as fully observed. 




18.1.5 Pros and cons 


Tree models are popular for several reasons: 
• They are easy to interpret.

• They can easily handle mixed discrete and continuous inputs.

• They are insensitive to monotone transformations of the inputs (because the split points are based
on ranking the data points), so there is no need to standardize the data.

• They perform automatic variable selection.

• They are relatively robust to outliers.

• They are fast to fit, and scale well to large data sets.

• They can handle missing input features.


Figure 18.3: (a) A decision tree of depth 2 fit to the iris data, using just the petal length and petal width
features. Leaf nodes are color coded according to the majority class. The number of training samples that pass
from the root to each node is shown inside each box, as well as how many of these values fall into each class.
This can be normalized to get a distribution over class labels for each node. (b) Decision surface induced by
(a). (c) Fit to data where we omit a single data point (shown by red star). (d) Ensemble of the two models in
(b) and (c). Generated by dtree_sensitivity.ipynb.
However, tree models also have some disadvantages. The primary one is that they do not predict
very accurately compared to other kinds of model. This is in part due to the greedy nature of the 
tree construction algorithm. 

A related problem is that trees are unstable: small changes to the input data can have large effects 
on the structure of the tree, due to the hierarchical nature of the tree-growing process, causing errors 
at the top to affect the rest of the tree. For example, consider the tree in Figure 18.3b. Omitting 
even a single data point from the training set can result in a dramatically different decision surface, 
as shown in Figure 18.3c, due to the use of axis parallel splits. (Omitting features can also cause 
instability.) In Section 18.3 and Section 18.4, we will turn this instability into a virtue. 


18.2 Ensemble learning 


In Section 18.1, we saw that decision trees can be quite unstable, in the sense that their predictions 
might vary a lot if the training data is perturbed. In other words, decision trees are a high variance 
estimator. A simple way to reduce variance is to average multiple models. This is called ensemble 
learning. The resulting model has the form

$$f(y|\mathbf{x}) = \frac{1}{|\mathcal{M}|} \sum_{m \in \mathcal{M}} f_m(y|\mathbf{x}) \tag{18.10}$$


where $f_m$ is the m’th base model. The ensemble will have similar bias to the base models, but lower
variance, generally resulting in improved overall performance (see Section 4.7.6.3 for details on the 
bias-variance tradeoff). 

Averaging is a sensible way to combine predictions from regression models. For classifiers, it can 
sometimes be better to take a majority vote of the outputs. (This is sometimes called a committee 
method.) To see why this can help, suppose each base model is a binary classifier with an accuracy
of $\theta$, and suppose class 1 is the correct class. Let $Y_m \in \{0, 1\}$ be the prediction for the m’th model,
and let $S = \sum_{m=1}^{M} Y_m$ be the number of votes for class 1. We define the final predictor to be the
majority vote, i.e., class 1 if S > M/2 and class 0 otherwise. The probability that the ensemble will
pick class 1 is

$$p = \Pr(S > M/2) = 1 - B(M/2, M, \theta) \tag{18.11}$$

where $B(x, M, \theta)$ is the cdf of the binomial distribution with parameters M and $\theta$ evaluated at x.
For $\theta = 0.51$ and M = 1000, we get p = 0.73 and with M = 10,000 we get p = 0.97.

The performance of the voting approach is dramatically improved, because we assumed each 
predictor made independent errors. In practice, their mistakes may be correlated, but as long as we 
ensemble sufficiently diverse models, we can still come out ahead. 
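The probability in Equation (18.11) can be evaluated numerically under the independence assumption. The sketch below sums the binomial pmf over all vote counts above M/2, working in log space so that large M does not underflow; the helper names are our own.

```python
import math

def log_binom_pmf(k, M, theta):
    """Log of the Binomial(M, theta) pmf at k, computed via log-gamma for stability."""
    log_comb = math.lgamma(M + 1) - math.lgamma(k + 1) - math.lgamma(M - k + 1)
    return log_comb + k * math.log(theta) + (M - k) * math.log1p(-theta)

def majority_vote_accuracy(M, theta):
    """Pr(S > M/2) when S ~ Binomial(M, theta), as in Equation (18.11)."""
    first = M // 2 + 1  # smallest integer strictly greater than M/2
    return sum(math.exp(log_binom_pmf(k, M, theta)) for k in range(first, M + 1))

for M in (1, 1000, 10_000):
    print(M, round(majority_vote_accuracy(M, 0.51), 3))
```

With a single voter the accuracy is just $\theta$; the gain for larger M comes entirely from the assumed independence of the errors.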


18.2.1 Stacking 


An alternative to using an unweighted average or majority vote is to learn how to combine the base 
models, by using 


$$f(y|\mathbf{x}) = \sum_{m \in \mathcal{M}} w_m f_m(y|\mathbf{x}) \tag{18.12}$$

This is called stacking, which stands for “stacked generalization” [Wol92]. Note that the combination
weights used by stacking need to be trained on a separate dataset, otherwise they would put all their 
mass on the best performing base model. 


18.2.2 Ensembling is not Bayes model averaging 


It is worth noting that an ensemble of models is not the same as using Bayes model averaging over 
models (Section 4.6), as pointed out in [Min00]. An ensemble considers a larger hypothesis class of 
the form 


$$p(y|\mathbf{x}, \mathbf{w}, \boldsymbol{\theta}) = \sum_{m \in \mathcal{M}} w_m \, p(y|\mathbf{x}, \boldsymbol{\theta}_m) \tag{18.13}$$


whereas BMA uses 


$$p(y|\mathbf{x}, \mathcal{D}) = \sum_{m \in \mathcal{M}} p(m|\mathcal{D}) \, p(y|\mathbf{x}, m, \mathcal{D}) \tag{18.14}$$


The key difference is that in the case of BMA, the weights p(m|D) sum to one, and in the limit of 
infinite data, only a single model will be chosen (namely the MAP model). By contrast, the ensemble 
weights $w_m$ are arbitrary, and don’t collapse in this way to a single model.


18.3 Bagging 


In this section, we discuss bagging [Bre96], which stands for “bootstrap aggregating”. This is a
simple form of ensemble learning in which we fit M different base models to different randomly 
sampled versions of the data; this encourages the different models to make diverse predictions. The 
datasets are sampled with replacement (a technique known as bootstrap sampling, Section 4.7.3), so 
a given example may appear multiple times, until we have a total of N examples per model (where 
N is the number of original data points). 

The disadvantage of bootstrap is that each base model only sees, on average, 63% of the unique 
input examples. To see why, note that the chance that a single item will not be selected from a set 
of size N in any of N draws is $(1 - 1/N)^N$. In the limit of large N, this becomes $e^{-1} \approx 0.37$, which
means only 1 — 0.37 = 0.63 of the data points will be selected. 
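The 63% figure is easy to check. The sketch below compares the exact expression $1 - (1 - 1/N)^N$ with the fraction of distinct indices in one simulated bootstrap sample; the function names and the random seed are arbitrary choices of ours.

```python
import math
import random

def expected_unique_fraction(N):
    """Probability that a given example appears at least once in N draws with replacement."""
    return 1.0 - (1.0 - 1.0 / N) ** N

def simulated_unique_fraction(N, seed=0):
    """Fraction of distinct indices in one bootstrap sample of size N."""
    rng = random.Random(seed)
    sample = [rng.randrange(N) for _ in range(N)]
    return len(set(sample)) / N

for N in (10, 100, 10_000):
    print(N, round(expected_unique_fraction(N), 4), round(simulated_unique_fraction(N), 4))
print("limit:", round(1.0 - math.exp(-1.0), 4))
```

The indices that never appear in a given sample are exactly the out-of-bag instances for the corresponding base model.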

The 37% of the training instances that are not used by a given base model are called out-of-bag 
instances (oob). We can use the predicted performance of the base model on these oob instances as 
an estimate of test set performance. This provides a useful alternative to cross validation. 

The main advantage of bootstrap is that it prevents the ensemble from relying too much on any 
individual training example, which enhances robustness and generalization [Gra04]. For example, 
comparing Figure 18.3b and Figure 18.3c, we see that omitting a single example from the training set 
can have a large impact on the decision tree that we learn (even though the tree growing algorithm 
is otherwise deterministic). By averaging the predictions from both of these models, we get the more 
reasonable prediction model in Figure 18.3d. This advantage generally increases with the size of the 
ensemble, as shown in Figure 18.4. (Of course, larger ensembles take more memory and more time.) 

Bagging does not always improve performance. In particular, it relies on the base models being 
unstable estimators, so that omitting some of the data significantly changes the resulting model fit. 
[Figure 18.4 panel titles: Decision Tree, test accuracy=0.86; Bag of 10 decision trees, test accuracy=0.91.]

Figure 18.4: (a) A single decision tree. (b-c) Bagging ensemble of 10 and 50 trees. (d) Random forest of 50
trees. Adapted from Figure 7.5 of [Gér19]. Generated by bagging_trees.ipynb and rf_demo_2d.ipynb.


This is the case for decision trees, but not for other models, such as nearest neighbor classifiers. For 
neural networks, the story is more mixed. They can be unstable wrt their training set. On the other 
hand, deep networks will underperform if they only see 63% of the data, so bagged DNNs do not 
usually work well [NTL20]. 


18.4 Random forests 


Bagging relies on the assumption that re-running the same learning algorithm on different subsets of 
the data will result in sufficiently diverse base models. The technique known as random forests 
[Bre01] tries to decorrelate the base learners even further by learning trees based on a randomly 
chosen subset of input variables (at each node of the tree), as well as a randomly chosen subset of 
data cases. It does this by modifying Equation (18.5) so that the feature split dimension j is optimized
over a random subset of the features, $S_i \subseteq \{1, \ldots, D\}$.
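The modification to the split search is small, as the following sketch suggests. It restricts the outer minimization of Equation (18.5) to a freshly sampled feature subset at each node; the function name, the Gini cost, and the subset size passed by the caller are our own illustrative choices (a subset size near $\sqrt{D}$ is a common heuristic for classification, but that is an assumption, not something fixed by the method).

```python
import random
from collections import Counter

def gini(labels):
    n = len(labels)
    return 1.0 - sum((k / n) ** 2 for k in Counter(labels).values())

def best_split_in_subset(X, y, n_candidates, rng):
    """Minimize Equation (18.5) over a random feature subset S_i of size n_candidates."""
    n, D = len(y), len(X[0])
    subset = rng.sample(range(D), n_candidates)  # S_i, redrawn at every node
    best = None
    for j in subset:
        for t in sorted({row[j] for row in X})[:-1]:
            left = [y[i] for i in range(n) if X[i][j] <= t]
            right = [y[i] for i in range(n) if X[i][j] > t]
            score = len(left) / n * gini(left) + len(right) / n * gini(right)
            if best is None or score < best[0]:
                best = (score, j, t)
    return subset, best

# Hypothetical toy data with D = 4 features.
X = [[1, 5, 0, 2], [2, 5, 1, 2], [3, 6, 0, 1], [4, 6, 1, 1]]
y = [0, 0, 1, 1]
subset, best = best_split_in_subset(X, y, n_candidates=2, rng=random.Random(0))
print(subset, best)
```

Because each node sees a different subset, two trees grown on similar bootstrap samples can still split on different features, which is the source of the extra decorrelation.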

For example, consider the email spam dataset [HTF09, p301]. This dataset contains 4601 email 
messages, each of which is classified as spam (1) or non-spam (0). The data was open sourced by 
George Forman from Hewlett-Packard (HP) Labs. 

There are 57 quantitative (real-valued) features, as follows: 


• 48 features corresponding to the percentage of words in the email that match a given word, such
as “remove” or “labs”.

• 6 features corresponding to the percentage of characters in the email that match a given character,
namely ; . [ ! $ #

• 3 features corresponding to the average length, max length, and sum of lengths of uninterrupted
sequences of capital letters. (These features are called CAPAVE, CAPMAX and CAPTOT.)


Figure 18.5: Predictive accuracy vs size of tree ensemble for bagging, random forests and gradient boosting
with log loss. Adapted from Figure 15.1 of [HTF09]. Generated by spam_tree_ensemble_compare.ipynb.


Figure 18.5 shows that random forests work much better than bagged decision trees, because many 
input features are irrelevant. (We also see that a method called “boosting”, discussed in Section 18.5, 
works even better; however, this requires sequentially fitting trees, whereas random forests can be fit 
in parallel.) 


18.5 Boosting 


Ensembles of trees, whether fit by bagging or the random forest algorithm, correspond to a model
of the form

$$f(\mathbf{x}; \boldsymbol{\theta}) = \sum_{m=1}^{M} \beta_m F_m(\mathbf{x}; \boldsymbol{\theta}_m) \tag{18.15}$$

where $F_m$ is the m’th tree, and $\beta_m$ is the corresponding weight, often set to $\beta_m = 1/M$. We can
generalize this by allowing the $F_m$ functions to be general function approximators, such as neural
networks, not just trees. The result is called an additive model [HTF09]. We can think of this as a
linear model with adaptive basis functions. The goal, as usual, is to minimize the empirical loss
(with an optional regularizer):

$$\mathcal{L}(f) = \sum_{i=1}^{N} \ell(y_i, f(\mathbf{x}_i)) \tag{18.16}$$


Boosting [Sch90; FS96] is an algorithm for sequentially fitting additive models where each $F_m$
is a binary classifier that returns $F_m \in \{-1, +1\}$. In particular, we first fit $F_1$ on the original data,
and then we weight the data samples by the errors made by $F_1$, so misclassified examples get more
weight. Next we fit $F_2$ to this weighted data set. We keep repeating this process until we have fit
the desired number M of components. (M is a hyper-parameter that controls the complexity of the 
overall model, and can be chosen by monitoring performance on a validation set, and using early 
stopping.) 

It can be shown that, as long as each $F_m$ has an accuracy that is better than chance (even on the
weighted dataset), then the final ensemble of classifiers will have higher accuracy than any given
component. That is, if $F_m$ is a weak learner (so its accuracy is only slightly better than 50%),
then we can boost its performance using the above procedure so that the final f becomes a strong
learner. (See e.g., [SF12] for more details on the learning theory approach to boosting.)

Note that boosting reduces the bias of the strong learner, by fitting trees that depend on each 
other, whereas bagging and RF reduce the variance by fitting independent trees. In many cases, 
boosting can work better. See Figure 18.5 for an example. 

The original boosting algorithm focused on binary classification with a particular loss function 
that we will explain in Section 18.5.3, and was derived from the PAC learning theory framework 
(see Section 5.4.4). In the rest of this section, we focus on a more statistical version of boosting, 
due to [FHT00; Fri01], which works with arbitrary loss functions, making the method suitable for 
regression, multi-class classification, ranking, etc. Our presentation is based on [HTF09, ch10] and 
[BH07], which should be consulted for further details. 


18.5.1 Forward stagewise additive modeling 


In this section, we discuss forward stagewise additive modeling, in which we sequentially 
optimize the objective in Equation (18.16) for general (differentiable) loss functions, where f is an 
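additive model as in Equation (18.15). That is, at iteration m, we compute

$$(\beta_m, \boldsymbol{\theta}_m) = \arg\min_{\beta, \boldsymbol{\theta}} \sum_{i=1}^{N} \ell(y_i, f_{m-1}(\mathbf{x}_i) + \beta F(\mathbf{x}_i; \boldsymbol{\theta})) \tag{18.17}$$

We then set

$$f_m(\mathbf{x}) = f_{m-1}(\mathbf{x}) + \beta_m F(\mathbf{x}; \boldsymbol{\theta}_m) = f_{m-1}(\mathbf{x}) + \beta_m F_m(\mathbf{x}) \tag{18.18}$$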


(Note that we do not adjust the parameters of previously added models.) The details on how to 
perform this optimization step depend on the loss function that we choose, and (in some cases) on 
the form of the weak learner F, as we discuss below. 


18.5.2 Quadratic loss and least squares boosting 


Suppose we use squared error loss, $\ell(y, \hat{y}) = (y - \hat{y})^2$. In this case, the i’th term in the objective at
step m becomes

$$\ell(y_i, f_{m-1}(\mathbf{x}_i) + \beta F(\mathbf{x}_i; \boldsymbol{\theta})) = (y_i - f_{m-1}(\mathbf{x}_i) - \beta F(\mathbf{x}_i; \boldsymbol{\theta}))^2 = (r_{im} - \beta F(\mathbf{x}_i; \boldsymbol{\theta}))^2 \tag{18.19}$$

where $r_{im} = y_i - f_{m-1}(\mathbf{x}_i)$ is the residual of the current model on the i’th observation. We can
minimize the above objective by simply setting $\beta = 1$, and fitting F to the residual errors. This is
called least squares boosting [BY03].
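The residual-fitting loop is short enough to write out in full. The following is a minimal sketch of least squares boosting on a 1d dataset, using a regression stump (a tree with a single split) as the weak learner and starting from $f_0 = 0$; both of these choices, and the toy data, are our own assumptions for illustration.

```python
def fit_stump(x, r):
    """Least squares regression stump on 1d inputs: returns (threshold, left mean, right mean)."""
    best = None
    for t in sorted(set(x))[:-1]:
        left = [ri for xi, ri in zip(x, r) if xi <= t]
        right = [ri for xi, ri in zip(x, r) if xi > t]
        lm, rm = sum(left) / len(left), sum(right) / len(right)
        sse = sum((ri - lm) ** 2 for ri in left) + sum((ri - rm) ** 2 for ri in right)
        if best is None or sse < best[0]:
            best = (sse, t, lm, rm)
    return best[1:]

def stump_predict(stump, xi):
    t, lm, rm = stump
    return lm if xi <= t else rm

def l2_boost(x, y, M):
    """Forward stagewise fitting with squared error: each stump is fit to the current residuals."""
    f = [0.0] * len(y)  # f_0 = 0 (an assumption of this sketch)
    stumps, mse_history = [], []
    for _ in range(M):
        residuals = [yi - fi for yi, fi in zip(y, f)]  # r_im = y_i - f_{m-1}(x_i)
        stump = fit_stump(x, residuals)
        f = [fi + stump_predict(stump, xi) for fi, xi in zip(f, x)]  # beta = 1
        stumps.append(stump)
        mse_history.append(sum((yi - fi) ** 2 for yi, fi in zip(y, f)) / len(y))
    return stumps, mse_history

x = [i / 10 for i in range(20)]
y = [(xi - 1.0) ** 2 for xi in x]
stumps, mse_history = l2_boost(x, y, M=10)
print([round(v, 4) for v in mse_history])
```

The training error cannot increase from one round to the next, because predicting zero in both leaves is always one of the options available to the stump.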
Figure 18.6: Illustration of boosting using a regression tree of depth 2 applied to a 1d dataset. Adapted from
Figure 7.9 of [Gér19]. Generated by boosted_regr_trees.ipynb.


We give an example of this process in Figure 18.6, where we use a regression tree of depth 2 as the 
weak learner. On the left, we show the result of fitting the weak learner to the residuals, and on the 
right, we show the current strong learner. We see how each new weak learner that is added to the 
ensemble corrects the errors made by earlier versions of the model. 


18.5.3 Exponential loss and AdaBoost 


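Suppose we are interested in binary classification, i.e., predicting $\tilde{y}_i \in \{-1, +1\}$. Let us assume the
weak learner computes

$$p(y = 1|\mathbf{x}) = \frac{e^{F(\mathbf{x})}}{e^{-F(\mathbf{x})} + e^{F(\mathbf{x})}} = \frac{1}{1 + e^{-2F(\mathbf{x})}} \tag{18.20}$$

so $F(\mathbf{x})$ returns half the log odds. We know from Equation (10.13) that the negative log likelihood is
given by

$$\ell(\tilde{y}, F(\mathbf{x})) = \log\left(1 + e^{-2\tilde{y}F(\mathbf{x})}\right) \tag{18.21}$$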
Figure 18.7: Illustration of various loss functions for binary classification (0-1 loss, hinge loss, log loss,
and exp loss). The horizontal axis is the margin $m(\mathbf{x}) = \tilde{y}F(\mathbf{x})$, the vertical axis is the loss.
The log loss uses log base 2. Generated by hinge_loss_plot.ipynb.


We can minimize this by ensuring that the margin $m(\mathbf{x}) = \tilde{y}F(\mathbf{x})$ is as large as possible. We see
from Figure 18.7 that the log loss is a smooth upper bound on the 0-1 loss. We also see that it 
penalizes negative margins more heavily than positive ones, as desired (since positive margins are 
already correctly classified). 

However, we can also use other loss functions. In this section, we consider the exponential loss 


$$\ell(\tilde{y}, F(\mathbf{x})) = \exp(-\tilde{y}F(\mathbf{x})) \tag{18.22}$$


We see from Figure 18.7 that this is also a smooth upper bound on the 0-1 loss. In the population 
setting (with infinite sample size), the optimal solution to the exponential loss is the same as for log 
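loss. To see this, we can just set the derivative of the expected loss (for each $\mathbf{x}$) to zero:

$$\frac{\partial}{\partial F(\mathbf{x})} \mathbb{E}\left[e^{-\tilde{y}F(\mathbf{x})} \mid \mathbf{x}\right] = \frac{\partial}{\partial F(\mathbf{x})} \left[ p(\tilde{y} = 1|\mathbf{x}) e^{-F(\mathbf{x})} + p(\tilde{y} = -1|\mathbf{x}) e^{F(\mathbf{x})} \right] \tag{18.23}$$

$$= -p(\tilde{y} = 1|\mathbf{x}) e^{-F(\mathbf{x})} + p(\tilde{y} = -1|\mathbf{x}) e^{F(\mathbf{x})} \tag{18.24}$$

$$= 0 \;\Rightarrow\; \frac{p(\tilde{y} = 1|\mathbf{x})}{p(\tilde{y} = -1|\mathbf{x})} = e^{2F(\mathbf{x})} \tag{18.25}$$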


However, it turns out that the exponential loss is easier to optimize in the boosting setting, as we 
show below. (We consider the log loss case in Section 18.5.4.) 
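Both claims about these surrogate losses can be checked numerically. The sketch below verifies, on a grid of margins, that the log loss of Equation (18.21) (in base 2) and the exponential loss of Equation (18.22) upper bound the 0-1 loss, and that the minimizer of the expected exponential loss is half the log odds. The grid, the value p = 0.8, and the convention that a margin of exactly zero counts as correct are our own choices.

```python
import math

def zero_one_loss(m):
    return 1.0 if m < 0 else 0.0  # convention here: a margin of exactly 0 counts as correct

def log_loss(m):
    return math.log2(1.0 + math.exp(-2.0 * m))  # Equation (18.21) in base 2, in terms of the margin

def exp_loss(m):
    return math.exp(-m)  # Equation (18.22)

margins = [k / 10 for k in range(-30, 31)]
bounds_hold = all(log_loss(m) >= zero_one_loss(m) and exp_loss(m) >= zero_one_loss(m)
                  for m in margins)

def expected_exp_loss(F, p):
    """E[exp(-y F)] when y = +1 with probability p and y = -1 otherwise."""
    return p * math.exp(-F) + (1.0 - p) * math.exp(F)

p = 0.8
F_star = 0.5 * math.log(p / (1.0 - p))             # half the log odds, from Equation (18.25)
recovered = 1.0 / (1.0 + math.exp(-2.0 * F_star))  # plug F_star back into Equation (18.20)
print(bounds_hold, round(F_star, 4), round(recovered, 4))
```

Plugging the minimizer back into Equation (18.20) returns the original class probability, which is the sense in which the two losses share the same population solution.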

We now discuss how to solve for the m’th weak learner, $F_m$, when we use exponential loss. We
will assume that the base classifier $F_m$ returns a binary class label; the resulting algorithm is called
discrete AdaBoost [FHT00]. If $F_m$ returns a probability instead, a modified algorithm, known as
real AdaBoost, can be used [FHT00].


At step m we have to minimize 


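$$L_m(F) = \sum_{i=1}^{N} \exp\left[-\tilde{y}_i (f_{m-1}(\mathbf{x}_i) + \beta F(\mathbf{x}_i))\right] = \sum_{i=1}^{N} w_{i,m} \exp(-\beta \tilde{y}_i F(\mathbf{x}_i)) \tag{18.26}$$

where $w_{i,m} \triangleq \exp(-\tilde{y}_i f_{m-1}(\mathbf{x}_i))$ is a weight applied to datacase i, and $\tilde{y}_i \in \{-1, +1\}$. We can rewrite
this objective as follows:

$$L_m = e^{-\beta} \sum_{\tilde{y}_i = F(\mathbf{x}_i)} w_{i,m} + e^{\beta} \sum_{\tilde{y}_i \ne F(\mathbf{x}_i)} w_{i,m} \tag{18.27}$$

$$= (e^{\beta} - e^{-\beta}) \sum_{i=1}^{N} w_{i,m} \mathbb{I}(\tilde{y}_i \ne F(\mathbf{x}_i)) + e^{-\beta} \sum_{i=1}^{N} w_{i,m} \tag{18.28}$$

Consequently the optimal function to add is

$$F_m = \arg\min_{F} \sum_{i=1}^{N} w_{i,m} \mathbb{I}(\tilde{y}_i \ne F(\mathbf{x}_i)) \tag{18.29}$$

This can be found by applying the weak learner to a weighted version of the dataset, with weights
$w_{i,m}$.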

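All that remains is to solve for the size of the update, $\beta$. Substituting $F_m$ into $L_m$ and solving for
$\beta$ we find

$$\beta_m = \frac{1}{2} \log \frac{1 - \text{err}_m}{\text{err}_m} \tag{18.30}$$

where

$$\text{err}_m = \frac{\sum_{i=1}^{N} w_{i,m} \mathbb{I}(\tilde{y}_i \ne F_m(\mathbf{x}_i))}{\sum_{i=1}^{N} w_{i,m}} \tag{18.31}$$

Therefore the overall update becomes

$$f_m(\mathbf{x}) = f_{m-1}(\mathbf{x}) + \beta_m F_m(\mathbf{x}) \tag{18.32}$$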


After updating the strong learner, we need to recompute the weights for the next iteration, as 
follows: 


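$$w_{i,m+1} = e^{-\tilde{y}_i f_m(\mathbf{x}_i)} = e^{-\tilde{y}_i f_{m-1}(\mathbf{x}_i) - \tilde{y}_i \beta_m F_m(\mathbf{x}_i)} = w_{i,m} e^{-\tilde{y}_i \beta_m F_m(\mathbf{x}_i)} \tag{18.33}$$

If $\tilde{y}_i = F_m(\mathbf{x}_i)$, then $\tilde{y}_i F_m(\mathbf{x}_i) = 1$, and if $\tilde{y}_i \ne F_m(\mathbf{x}_i)$, then $\tilde{y}_i F_m(\mathbf{x}_i) = -1$. Hence
$-\tilde{y}_i F_m(\mathbf{x}_i) = 2\mathbb{I}(\tilde{y}_i \ne F_m(\mathbf{x}_i)) - 1$, so the update becomes

$$w_{i,m+1} = w_{i,m} e^{\beta_m (2\mathbb{I}(\tilde{y}_i \ne F_m(\mathbf{x}_i)) - 1)} = w_{i,m} e^{2\beta_m \mathbb{I}(\tilde{y}_i \ne F_m(\mathbf{x}_i))} e^{-\beta_m} \tag{18.34}$$

Since the $e^{-\beta_m}$ is constant across all examples, it can be dropped. If we then define $\alpha_m = 2\beta_m$, the
update becomes

$$w_{i,m+1} = \begin{cases} w_{i,m} e^{\alpha_m} & \text{if } \tilde{y}_i \ne F_m(\mathbf{x}_i) \\ w_{i,m} & \text{otherwise} \end{cases} \tag{18.35}$$

Thus we see that we exponentially increase weights of misclassified examples. The resulting algorithm
is shown in Algorithm 18.1, and is known as Adaboost.M1 [FS96].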

A multiclass generalization of exponential loss, and an adaboost-like algorithm to minimize it, 
known as SAMME (stagewise additive modeling using a multiclass exponential loss function), is 
described in [Has+09]. This is implemented in scikit-learn (the AdaBoostClassifier class).
Algorithm 18.1: Adaboost.M1, for binary classification with exponential loss

$w_i = 1/N$
for $m = 1 : M$ do
    Fit a classifier $F_m(\mathbf{x})$ to the training set using weights $\mathbf{w}$
    Compute $\text{err}_m = \frac{\sum_{i=1}^{N} w_{i,m} \mathbb{I}(\tilde{y}_i \ne F_m(\mathbf{x}_i))}{\sum_{i=1}^{N} w_{i,m}}$
    Compute $\alpha_m = \log[(1 - \text{err}_m)/\text{err}_m]$
    Set $w_i \leftarrow w_i \exp[\alpha_m \mathbb{I}(\tilde{y}_i \ne F_m(\mathbf{x}_i))]$
Return $f(\mathbf{x}) = \text{sgn}\left[\sum_{m=1}^{M} \alpha_m F_m(\mathbf{x})\right]$
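To make the steps of Algorithm 18.1 concrete, here is a minimal sketch in plain Python. It uses decision stumps as the weak learner and a small 1d dataset that no single stump can classify perfectly; the stump search, the clipping of the error rate (a numerical guard that is not part of the algorithm), and the toy data are our own assumptions.

```python
import math

def fit_weighted_stump(X, y, w):
    """Decision stump minimizing the weighted 0-1 error; returns (error, feature, threshold, sign)."""
    best = None
    for j in range(len(X[0])):
        for t in sorted({row[j] for row in X}):
            for sign in (+1, -1):
                preds = [sign if row[j] <= t else -sign for row in X]
                err = sum(wi for wi, pi, yi in zip(w, preds, y) if pi != yi)
                if best is None or err < best[0]:
                    best = (err, j, t, sign)
    return best

def stump_predict(stump, row):
    _, j, t, sign = stump
    return sign if row[j] <= t else -sign

def adaboost_m1(X, y, M):
    """Discrete AdaBoost following Algorithm 18.1; labels must be in {-1, +1}."""
    N = len(y)
    w = [1.0 / N] * N
    ensemble = []
    for _ in range(M):
        stump = fit_weighted_stump(X, y, w)
        err = stump[0] / sum(w)
        err = min(max(err, 1e-10), 1 - 1e-10)  # guard against log(0); not part of Algorithm 18.1
        alpha = math.log((1 - err) / err)
        # Upweight only the misclassified examples (Equation 18.35).
        w = [wi * math.exp(alpha) if stump_predict(stump, row) != yi else wi
             for wi, row, yi in zip(w, X, y)]
        ensemble.append((alpha, stump))
    return ensemble

def predict(ensemble, row):
    score = sum(alpha * stump_predict(stump, row) for alpha, stump in ensemble)
    return 1 if score >= 0 else -1

X = [[float(i)] for i in range(10)]
y = [1, 1, 1, -1, -1, -1, 1, 1, 1, -1]
for M in (1, 3):
    ensemble = adaboost_m1(X, y, M)
    errors = sum(predict(ensemble, row) != yi for row, yi in zip(X, y))
    print(M, errors)
```

On this data the weighted vote of a few stumps can represent the alternating pattern of labels even though each stump alone is a single threshold.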


18.5.4 LogitBoost 


The trouble with exponential loss is that it puts a lot of weight on misclassified examples, as is 
apparent from the exponential blowup on the left hand side of Figure 18.7. This makes the method 
very sensitive to outliers (mislabeled examples). In addition, $e^{-\tilde{y}f}$ is not the logarithm of any pmf
for binary variables $\tilde{y} \in \{-1, +1\}$; consequently we cannot recover probability estimates from $f(\mathbf{x})$.

A natural alternative is to use log loss, as we discussed in Section 18.5.3. This only punishes 
mistakes linearly, as is clear from Figure 18.7. Furthermore, it means that we will be able to extract 
probabilities from the final learned function, using 


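$$p(y = 1|\mathbf{x}) = \frac{e^{f(\mathbf{x})}}{e^{-f(\mathbf{x})} + e^{f(\mathbf{x})}} = \frac{1}{1 + e^{-2f(\mathbf{x})}} \tag{18.36}$$

The goal is to minimize the expected log-loss, given by

$$L_m(F) = \sum_{i=1}^{N} \log\left[1 + \exp\left(-2\tilde{y}_i (f_{m-1}(\mathbf{x}_i) + F(\mathbf{x}_i))\right)\right] \tag{18.37}$$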


By performing a Newton update on this objective (similar to IRLS), one can derive the algorithm 
shown in Algorithm 18.2. This is known as LogitBoost [FHT00]. The key subroutine is the ability
of the weak learner F to solve a weighted least squares problem. This method can be generalized to 
the multi-class setting, as explained in [FHT00]. 


18.5.5 Gradient boosting 


Rather than deriving new versions of boosting for every different loss function, it is possible to derive 
a generic version, known as gradient boosting [Fri01; Mas+00]. To explain this, imagine solving 
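$\hat{f} = \arg\min_f \mathcal{L}(f)$ by performing gradient descent in the space of functions. Since functions are infinite
dimensional objects, we will represent them by their values on the training set, $\mathbf{f} = (f(\mathbf{x}_1), \ldots, f(\mathbf{x}_N))$.
At step m, let $\mathbf{g}_m$ be the gradient of $\mathcal{L}(\mathbf{f})$ evaluated at $\mathbf{f} = \mathbf{f}_{m-1}$:

$$g_{im} = \left[\frac{\partial \ell(y_i, f(\mathbf{x}_i))}{\partial f(\mathbf{x}_i)}\right]_{f = f_{m-1}} \tag{18.38}$$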
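Algorithm 18.2: LogitBoost, for binary classification with log-loss

$w_i = 1/N$, $\pi_i = 1/2$
for $m = 1 : M$ do
    Compute the working response $z_i = \frac{y_i^* - \pi_i}{\pi_i(1 - \pi_i)}$ (where $y_i^* \in \{0, 1\}$ is the 0/1 encoding of the label)
    Compute the weights $w_i = \pi_i(1 - \pi_i)$
    $F_m = \arg\min_F \sum_{i=1}^{N} w_i (z_i - F(\mathbf{x}_i))^2$
    Update $f(\mathbf{x}) \leftarrow f(\mathbf{x}) + \frac{1}{2} F_m(\mathbf{x})$
    Compute $\pi_i = 1/(1 + \exp(-2f(\mathbf{x}_i)))$
Return $f(\mathbf{x}) = \text{sgn}\left[\sum_{m=1}^{M} F_m(\mathbf{x})\right]$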
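| Name | Loss | $-\partial \ell(y_i, f(\mathbf{x}_i)) / \partial f(\mathbf{x}_i)$ |
|---|---|---|
| Squared error | $\frac{1}{2}(y_i - f(\mathbf{x}_i))^2$ | $y_i - f(\mathbf{x}_i)$ |
| Absolute error | $\lvert y_i - f(\mathbf{x}_i) \rvert$ | $\text{sgn}(y_i - f(\mathbf{x}_i))$ |
| Exponential loss | $\exp(-\tilde{y}_i f(\mathbf{x}_i))$ | $\tilde{y}_i \exp(-\tilde{y}_i f(\mathbf{x}_i))$ |
| Binary Logloss | $\log(1 + e^{-\tilde{y}_i f_i})$ | $y_i - \pi_i$ |
| Multiclass logloss | $-\sum_c y_{ic} \log \pi_{ic}$ | $y_{ic} - \pi_{ic}$ |

Table 18.1: Some commonly used loss functions and their gradients. For binary classification problems, we
assume $\tilde{y}_i \in \{-1, +1\}$, and $\pi_i = \sigma(2f(\mathbf{x}_i))$. For regression problems, we assume $y_i \in \mathbb{R}$. Adapted from
[HTF09, p360] and [BH07, p483].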


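Gradients of some common loss functions are given in Table 18.1. We then make the update

$$\mathbf{f}_m = \mathbf{f}_{m-1} - \beta_m \mathbf{g}_m \tag{18.39}$$

where $\beta_m$ is the step length, chosen by

$$\beta_m = \arg\min_{\beta} \mathcal{L}(\mathbf{f}_{m-1} - \beta \mathbf{g}_m) \tag{18.40}$$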


In its current form, this is not much use, since it only optimizes f at a fixed set of N points, so we 
do not learn a function that can generalize. However, we can modify the algorithm by fitting a weak 
learner to approximate the negative gradient signal. That is, we use this update 


$$F_m = \arg\min_{F} \sum_{i=1}^{N} (-g_{im} - F(\mathbf{x}_i))^2 \tag{18.41}$$


The overall algorithm is summarized in Algorithm 18.3. We have omitted the line search step for $\beta_m$,
which is not strictly necessary, as argued in [BH07]. However, we have introduced a learning rate or
shrinkage factor $0 < \nu \le 1$, to control the size of the updates, for regularization purposes.

If we apply this algorithm using squared loss, we recover L2Boosting, since $-g_{im} = y_i - f_{m-1}(\mathbf{x}_i)$
is just the residual error. We can also apply this algorithm to other loss functions, such as absolute 
loss or Huber loss (Section 5.1.5.3), which is useful for robust regression problems. 
Algorithm 18.3: Gradient boosting

Initialize $f_0(\mathbf{x}) = \arg\min_F \sum_{i=1}^{N} L(y_i, F(\mathbf{x}_i))$
for $m = 1 : M$ do
    Compute the gradient residual using $r_{im} = -\left[\frac{\partial L(y_i, f(\mathbf{x}_i))}{\partial f(\mathbf{x}_i)}\right]_{f(\mathbf{x}_i) = f_{m-1}(\mathbf{x}_i)}$
    Use the weak learner to compute $F_m = \arg\min_F \sum_{i=1}^{N} (r_{im} - F(\mathbf{x}_i))^2$
    Update $f_m(\mathbf{x}) = f_{m-1}(\mathbf{x}) + \nu F_m(\mathbf{x})$
Return $f(\mathbf{x}) = f_M(\mathbf{x})$
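Algorithm 18.3 can be sketched generically by passing in the negative gradient of the loss. The code below is an illustration under several assumptions of ours: the weak learner is a 1d regression stump, the initial function $f_0$ is a constant (the mean for squared error, the median for absolute error), and the data is a noise-free step function chosen so the behavior is easy to follow.

```python
import statistics

def fit_stump(x, r):
    """Least squares regression stump on 1d inputs: returns (threshold, left mean, right mean)."""
    best = None
    for t in sorted(set(x))[:-1]:
        left = [ri for xi, ri in zip(x, r) if xi <= t]
        right = [ri for xi, ri in zip(x, r) if xi > t]
        lm, rm = sum(left) / len(left), sum(right) / len(right)
        sse = sum((ri - lm) ** 2 for ri in left) + sum((ri - rm) ** 2 for ri in right)
        if best is None or sse < best[0]:
            best = (sse, t, lm, rm)
    return best[1:]

def stump_predict(stump, xi):
    t, lm, rm = stump
    return lm if xi <= t else rm

def sign(v):
    return (v > 0) - (v < 0)

def gradient_boost(x, y, neg_gradient, f0, M, nu):
    """Fit M stumps to the negative gradient of the loss, with shrinkage factor nu."""
    f = [f0] * len(y)
    stumps = []
    for _ in range(M):
        r = [neg_gradient(yi, fi) for yi, fi in zip(y, f)]  # gradient residuals r_im
        stump = fit_stump(x, r)
        f = [fi + nu * stump_predict(stump, xi) for fi, xi in zip(f, x)]
        stumps.append(stump)
    return stumps, f

x = list(range(20))
y = [0.0] * 10 + [1.0] * 10

# Squared error: the negative gradient is the ordinary residual (Table 18.1).
_, f_sq = gradient_boost(x, y, lambda yi, fi: yi - fi, statistics.mean(y), M=5, nu=0.5)
final_mse = sum((yi - fi) ** 2 for yi, fi in zip(y, f_sq)) / len(y)

# Absolute error: the negative gradient is the sign of the residual (Table 18.1).
_, f_abs = gradient_boost(x, y, lambda yi, fi: sign(yi - fi), statistics.median(y), M=5, nu=0.1)
final_mae = sum(abs(yi - fi) for yi, fi in zip(y, f_abs)) / len(y)
print(final_mse, final_mae)
```

Only the `neg_gradient` argument and the initial constant change between the two runs; the weak learner always solves an ordinary least squares problem, which is the point of the generic formulation.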


For classification, we can use log-loss. In this case, we get an algorithm known as BinomialBoost 
[BH07]. The advantage of this over LogitBoost is that it does not need to be able to do weighted 
fitting: it just applies any black-box regression model to the gradient vector. To apply this to 
multi-class classification, we can fit C separate regression trees, using the pseudo residual of the form 


$$-g_{icm} = -\frac{\partial \ell(y_i, f_{1m}(\mathbf{x}_i), \ldots, f_{Cm}(\mathbf{x}_i))}{\partial f_{cm}(\mathbf{x}_i)} = \mathbb{I}(y_i = c) - \pi_{ic} \tag{18.42}$$

Although the trees are fit separately, their predictions are combined via a softmax transform

$$p(y = c|\mathbf{x}) = \frac{e^{f_c(\mathbf{x})}}{\sum_{c'=1}^{C} e^{f_{c'}(\mathbf{x})}} \tag{18.43}$$

When we have large datasets, we can use a stochastic variant in which we subsample (without
replacement) a random fraction of the data to pass to the regression tree at each iteration. This is
called stochastic gradient boosting [Fri99]. Not only is it faster, but it can also generalize better,
because subsampling the data is a form of regularization.


18.5.5.1 Gradient tree boosting 


In practice, gradient boosting nearly always assumes that the weak learner is a regression tree, which 
is a model of the form 


$$F_m(\mathbf{x}) = \sum_{j=1}^{J_m} w_{jm} \mathbb{I}(\mathbf{x} \in R_{jm}) \tag{18.44}$$

where $w_{jm}$ is the predicted output for region $R_{jm}$. (In general, $w_{jm}$ could be a vector.) This
combination is called gradient boosted regression trees, or gradient tree boosting. (A
related version is known as MART, which stands for “multiple additive regression trees” [FM03].)

To use this in gradient boosting, we first find good regions $R_{jm}$ for tree m using standard regression
tree learning (see Section 18.1) on the residuals; we then (re)solve for the weights of each leaf by
solving

$$\hat{w}_{jm} = \arg\min_{w} \sum_{\mathbf{x}_i \in R_{jm}} \ell(y_i, f_{m-1}(\mathbf{x}_i) + w) \tag{18.45}$$

For squared error (as used by gradient boosting), the optimal weight $\hat{w}_{jm}$ is just the mean of the
residuals in that leaf.


18.5.5.2 XGBoost 


XGBoost (https://github.com/dmlc/xgboost), which stands for “extreme gradient boosting”, is 
a very efficient and widely used implementation of gradient boosted trees, that adds a few more 
improvements beyond the description in Section 18.5.5.1. The details can be found in [CG16], but 
in brief, the extensions are as follows: it adds a regularizer on the tree complexity, it uses a second 
order approximation of the loss (from [FHT00]) instead of just a linear approximation, it samples 
features at internal nodes (as in random forests), and it uses various computer science methods (such 
as handling out-of-core computation for large datasets) to ensure scalability.²
In more detail, XGBoost optimizes the following regularized objective 


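$$\mathcal{L}(f) = \sum_{i=1}^{N} \ell(y_i, f(\mathbf{x}_i)) + \Omega(f) \tag{18.46}$$

where

$$\Omega(f) = \gamma J + \frac{1}{2} \lambda \sum_{j=1}^{J} w_j^2 \tag{18.47}$$

is the regularizer, where J is the number of leaves, and $\gamma \ge 0$ and $\lambda \ge 0$ are regularization coefficients.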
At the m’th step, the loss is given by 


$$\mathcal{L}_m(F_m) = \sum_{i=1}^{N} \ell(y_i, f_{m-1}(\mathbf{x}_i) + F_m(\mathbf{x}_i)) + \Omega(F_m) + \text{const} \tag{18.48}$$
We can compute a second order Taylor expansion of this as follows: 


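$$\mathcal{L}_m(F_m) \approx \sum_{i=1}^{N} \left[\ell(y_i, f_{m-1}(\mathbf{x}_i)) + g_{im} F_m(\mathbf{x}_i) + \frac{1}{2} h_{im} F_m^2(\mathbf{x}_i)\right] + \Omega(F_m) + \text{const} \tag{18.49}$$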


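where $h_{im}$ is the Hessian

$$h_{im} = \left[\frac{\partial^2 \ell(y_i, f(\mathbf{x}_i))}{\partial f(\mathbf{x}_i)^2}\right]_{f = f_{m-1}} \tag{18.50}$$

In the case of regression trees, we have $F(\mathbf{x}) = w_{q(\mathbf{x})}$, where $q : \mathbb{R}^D \to \{1, \ldots, J\}$ specifies which
leaf node $\mathbf{x}$ belongs to, and $\mathbf{w} \in \mathbb{R}^J$ are the leaf weights. Hence we can rewrite Equation (18.49) as
follows, dropping terms that are independent of $F_m$:

$$\mathcal{L}_m(q, \mathbf{w}) = \sum_{i=1}^{N} \left[g_{im} F_m(\mathbf{x}_i) + \frac{1}{2} h_{im} F_m^2(\mathbf{x}_i)\right] + \gamma J + \frac{1}{2} \lambda \sum_{j=1}^{J} w_j^2 \tag{18.51}$$

$$= \sum_{j=1}^{J} \left[\Big(\sum_{i \in I_j} g_{im}\Big) w_j + \frac{1}{2} \Big(\sum_{i \in I_j} h_{im} + \lambda\Big) w_j^2\right] + \gamma J \tag{18.52}$$

2. Some other popular gradient boosted trees packages are CatBoost (https://catboost.ai/) and LightGBM
(https://github.com/Microsoft/LightGBM).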


where $I_j = \{i : q(\mathbf{x}_i) = j\}$ is the set of indices of data points assigned to the j’th leaf.

Let us define $G_{jm} = \sum_{i \in I_j} g_{im}$ and $H_{jm} = \sum_{i \in I_j} h_{im}$. Then the above simplifies to


$$\mathcal{L}_m(q, \mathbf{w}) = \sum_{j=1}^{J} \left[G_{jm} w_j + \frac{1}{2} (H_{jm} + \lambda) w_j^2\right] + \gamma J \tag{18.53}$$
This is a quadratic in each w;, so the optimal weights are given by 


$$w_j^* = -\frac{G_{jm}}{H_{jm} + \lambda} \tag{18.54}$$


The loss for evaluating different tree structures q then becomes 


$$\mathcal{L}_m(q, \mathbf{w}^*) = -\frac{1}{2} \sum_{j=1}^{J} \frac{G_{jm}^2}{H_{jm} + \lambda} + \gamma J \tag{18.55}$$


We can greedily optimize this using a recursive node splitting procedure, as in Section 18.1. 
Specifically, for a given leaf j, we consider splitting it into a left and right half, $I = I_L \cup I_R$. We can
compute the gain (reduction in loss) of such a split as follows:

$$\text{gain} = \frac{1}{2} \left[\frac{G_L^2}{H_L + \lambda} + \frac{G_R^2}{H_R + \lambda} - \frac{(G_L + G_R)^2}{(H_L + H_R) + \lambda}\right] - \gamma \tag{18.56}$$

where $G_L = \sum_{i \in I_L} g_{im}$, $G_R = \sum_{i \in I_R} g_{im}$, $H_L = \sum_{i \in I_L} h_{im}$, and $H_R = \sum_{i \in I_R} h_{im}$. Thus we see
that it is not worth splitting a node if the gain is negative (i.e., the first term is less than $\gamma$).
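The leaf weight and split gain are cheap to evaluate once the per-example gradients and Hessians are known. The sketch below implements Equations (18.54) and (18.56) directly; it is not XGBoost's own code, and the function names and the toy values are our own. We use the squared error loss $\frac{1}{2}(y - f)^2$, for which $g_i = f_i - y_i$ and $h_i = 1$.

```python
def leaf_weight(G, H, lam):
    """Optimal leaf weight, Equation (18.54); assumes H + lam > 0."""
    return -G / (H + lam)

def split_gain(g_left, h_left, g_right, h_right, lam, gamma):
    """Gain of splitting a leaf into left and right children, Equation (18.56)."""
    GL, HL = sum(g_left), sum(h_left)
    GR, HR = sum(g_right), sum(h_right)

    def score(G, H):
        return G * G / (H + lam)

    return 0.5 * (score(GL, HL) + score(GR, HR) - score(GL + GR, HL + HR)) - gamma

# Squared error 0.5 * (y - f)^2 gives g_i = f_i - y_i and h_i = 1.
y = [1.0, 1.0, -1.0, -1.0]
f = [0.0, 0.0, 0.0, 0.0]
g = [fi - yi for fi, yi in zip(f, y)]
h = [1.0] * len(y)

print(split_gain(g[:2], h[:2], g[2:], h[2:], lam=1.0, gamma=0.5))  # worth splitting
print(split_gain(g[:2], h[:2], g[2:], h[2:], lam=1.0, gamma=2.0))  # penalty exceeds the benefit
print(leaf_weight(sum(g[:2]), sum(h[:2]), lam=1.0))                # shrunk towards zero by lam
```

Increasing $\gamma$ makes every split less attractive by the same amount, while increasing $\lambda$ shrinks both the leaf weights and the benefit attributed to small leaves.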

A fast approximation for evaluating this objective, that does not require sorting the features (for 
choosing the optimal threshold to split on), is described in [CG16]. 


18.6 Interpreting tree ensembles


Trees are popular because they are interpretable. Unfortunately, ensembles of trees (whether in the
form of bagging, random forests, or boosting) lose that property. Fortunately, there are some simple
methods we can use to interpret what function has been learned.
Figure 18.8: Feature importance of a random forest classifier trained to distinguish MNIST digits from classes
0 and 8. (The color scale runs from “Not important” to “Very important”.) Adapted from Figure 7.6 of [Gér19].
Generated by rf_feature_importance_mnist.ipynb.


18.6.1 Feature importance 


For a single decision tree T , [BFO84] proposed the following measure for feature importance of 
feature k: 


$$R_k(T) = \sum_{j=1}^{J-1} G_j \mathbb{I}(v_j = k) \tag{18.57}$$

where the sum is over all non-leaf (internal) nodes, $G_j$ is the gain in accuracy (reduction in cost) at
node j, and $v_j = k$ if node j uses feature k. We can get a more reliable estimate by averaging over
all trees in the ensemble:

$$R_k = \frac{1}{M} \sum_{m=1}^{M} R_k(T_m) \tag{18.58}$$


After computing these scores, we can normalize them so the largest value is 100%. We give some 
examples below. 

Figure 18.8 gives an example of estimating feature importance for a classifier trained to distinguish 
MNIST digits from classes 0 and 8. We see that it focuses on the parts of the image that differ 
between these classes. 

In Figure 18.9, we plot the relative importance of each of the features for the spam dataset 
(Section 18.4). Not surprisingly, we find that the most important features are the words “george” 
(the name of the recipient) and “hp” (the company he worked for), as well as the characters ! and $.
(Note it can be the presence or absence of these features that is informative.) 
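
The scores in Equations (18.57) and (18.58) are simple to compute once we know, for each internal node, which feature it splits on and what gain it achieved. The sketch below assumes each tree is represented as a list of (feature index, gain) pairs; this representation and the function names are hypothetical, chosen only to illustrate the arithmetic.

```python
import numpy as np

def tree_importance(nodes, num_features):
    """R_k(T): sum of the gains G_j over internal nodes j that split on feature k."""
    scores = np.zeros(num_features)
    for feature, gain in nodes:
        scores[feature] += gain
    return scores

def ensemble_importance(trees, num_features):
    """Average R_k(T_m) over the trees, then rescale so the largest value is 100."""
    avg = np.mean([tree_importance(t, num_features) for t in trees], axis=0)
    return 100.0 * avg / avg.max()

# Hypothetical ensemble of two trees; one (feature, gain) pair per internal node.
trees = [
    [(0, 0.30), (2, 0.10), (0, 0.05)],
    [(0, 0.20), (1, 0.15)],
]
importance = ensemble_importance(trees, num_features=3)
```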

Figure 18.9: Feature importance of a gradient boosted classifier trained to distinguish spam from non-spam
email. The dataset has X training examples with Y features, corresponding to token frequency. Adapted from 
Figure 10.6 of [HTF09]. Generated by spam_tree_ensemble_interpret.ipynb. 

Figure 18.10: (a) Partial dependence of log-odds of the spam class for 4 important predictors. The red
ticks at the base of the plot are deciles of the empirical distribution for this feature. (b) Joint partial 
dependence of log-odds on the features hp and !. Adapted from Figure 10.6-10.8 of [HTF09]. Generated by 
spam_tree_ensemble_interpret.ipynb. 


18.6.2 Partial dependency plots


After we have identified the most relevant input features, we can try to assess the impact they have 
on the output. A partial dependency plot for feature $k$ is a plot of

$$\bar{f}_k(x_k) = \frac{1}{N} \sum_{n=1}^{N} f(\mathbf{x}_{n,-k}, x_k) \tag{18.59}$$


vs $x_k$. Thus we marginalize out all features except $k$. In the case of a binary classifier, we can
convert this to log odds, $\log p(y=1|\mathbf{x})/p(y=0|\mathbf{x})$, before plotting. We illustrate this for our spam
example in Figure 18.10a for 4 different features. We see that as the frequency of ! and “remove” 
increases, so does the probability of spam. Conversely, as the frequency of “edu” or “hp” increases, 
the probability of spam decreases. 
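
Equation (18.59) can be computed for any black-box predictor by overwriting feature $k$ with a fixed value in every training example and averaging the predictions. The following sketch assumes a generic `predict` function that maps an (N, D) array to N outputs; the toy model and the grid are hypothetical.

```python
import numpy as np

def partial_dependence(predict, X, k, grid):
    """Average prediction when feature k is clamped to each value in grid."""
    values = []
    for v in grid:
        X_mod = X.copy()
        X_mod[:, k] = v                  # overwrite feature k for every example
        values.append(predict(X_mod).mean())
    return np.array(values)

# Hypothetical model f(x) = 2*x0 + x1*x2, evaluated on random inputs.
predict = lambda X: 2.0 * X[:, 0] + X[:, 1] * X[:, 2]
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
grid = np.linspace(-1.0, 1.0, 5)
pd0 = partial_dependence(predict, X, k=0, grid=grid)
```

For this toy model the curve is a straight line in $x_0$ with slope 2, shifted by the average of $x_1 x_2$ over the data.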

We can also try to capture interaction effects between features $j$ and $k$ by computing

$$\bar{f}_{jk}(x_j, x_k) = \frac{1}{N} \sum_{n=1}^{N} f(\mathbf{x}_{n,-jk}, x_j, x_k) \tag{18.60}$$


We illustrate this for our spam example in Figure 18.10b for hp and !. We see that higher frequency 
of ! makes it more likely to be spam, but much more so if the word “hp” is missing. 
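
The joint version in Equation (18.60) clamps two features at once. As a small sketch under the same assumptions as before (a generic `predict` function and hypothetical grids), an interaction shows up as a surface that is not a sum of two one-dimensional curves:

```python
import numpy as np

def joint_partial_dependence(predict, X, j, k, grid_j, grid_k):
    """Average prediction with features j and k clamped to each pair of grid values."""
    out = np.zeros((len(grid_j), len(grid_k)))
    for a, vj in enumerate(grid_j):
        for b, vk in enumerate(grid_k):
            X_mod = X.copy()
            X_mod[:, j] = vj
            X_mod[:, k] = vk
            out[a, b] = predict(X_mod).mean()
    return out

# Hypothetical model with an interaction between features 1 and 2.
predict = lambda X: X[:, 0] + X[:, 1] * X[:, 2]
X = np.random.default_rng(0).normal(size=(200, 3))
grid = np.array([-1.0, 0.0, 1.0])
pd12 = joint_partial_dependence(predict, X, 1, 2, grid, grid)
```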


Beyond Supervised Learning


19 Learning with Fewer Labeled Examples


Many ML models, especially neural networks, often have many more parameters than we have 
labeled training examples. For example, a ResNet CNN (Section 14.3.4) with 50 layers has 23 million 
parameters. Transformer models (Section 15.5) can be even bigger. Of course these parameters are 
highly correlated, so they are not independent “degrees of freedom”. Nevertheless, such big models 
are slow to train and, more importantly, they may easily overfit. This is particularly a problem 
when you do not have a large labeled training set. In this chapter, we discuss some ways to tackle 
this issue, beyond the generic regularization techniques we discussed in Section 13.5 such as early 
stopping, weight decay and dropout. 


19.1 Data augmentation 


Suppose we just have a single small labeled dataset. In some cases, we may be able to create 
artificially modified versions of the input vectors, which capture the kinds of variations we expect to 
see at test time, while keeping the original labels unchanged. This is called data augmentation.¹
We give some examples below, and then discuss why this approach works. 


19.1.1 Examples 


For image classification tasks, standard data augmentation methods include random crops, zooms, 
and mirror image flips, as illustrated in Figure 19.1. [GVZ16] gives a more sophisticated example, 
where they render text characters onto an image in a realistic way, thereby creating a very large 
dataset of text “in the wild”. They used this to train a state of the art visual text localization and 
reading system. Other examples of data augmentation include artificially adding background noise to
clean speech signals, and artificially replacing characters or words at random in text documents. 
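
As a minimal sketch of the image augmentations mentioned above, the following NumPy code produces random crops and mirror flips of an image stored as an (H, W, C) array. The function names and the crop size are illustrative assumptions, and a real pipeline would usually also resize the crops.

```python
import numpy as np

def random_crop(img, size, rng):
    """Cut a random size x size window out of an (H, W, C) image."""
    H, W = img.shape[:2]
    top = rng.integers(0, H - size + 1)
    left = rng.integers(0, W - size + 1)
    return img[top:top + size, left:left + size]

def random_flip(img, rng, p=0.5):
    """Mirror the image left-to-right with probability p."""
    return img[:, ::-1] if rng.random() < p else img

def augment(img, size, rng):
    """One random view of the image; the label is left unchanged."""
    return random_flip(random_crop(img, size, rng), rng)

rng = np.random.default_rng(0)
img = rng.random((32, 32, 3))          # stand-in for a real image
views = [augment(img, 24, rng) for _ in range(4)]
```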

If we can afford to train and test the model many times using different versions of the data, we can
learn which augmentations work best, using blackbox optimization methods such as RL (see e.g., 
[Cub+19]) or Bayesian optimization (see e.g., [Lim+19]); this is called AutoAugment. We can also 
learn to combine multiple augmentations together; this is called AugMix [Hen+20].

For some examples of augmentation in NLP, see e.g., [Fen+21]. 


1. The term “data augmentation” is also used in statistics to mean the addition of auxiliary latent variables to a model 
in order to speed up convergence of posterior inference algorithms [DM01]. 

Figure 19.1: Illustration of random crops and zooms of images. Generated by
image_augmentation_jax.ipynb.


19.1.2 Theoretical justification 


Data augmentation often significantly improves performance (predictive accuracy, robustness, etc). At 
first this might seem like we are getting something for nothing, since we have not provided additional 
data. However, the data augmentation mechanism can be viewed as a way to algorithmically inject 
prior knowledge. 

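To see this, recall that in standard ERM training, we minimize the empirical risk

$$R(f) = \int \ell(f(\mathbf{x}), y) \, p^*(\mathbf{x}, y) \, d\mathbf{x} \, dy \tag{19.1}$$

where we approximate $p^*(\mathbf{x}, y)$ by the empirical distribution

$$p_{\mathcal{D}}(\mathbf{x}, y) = \frac{1}{N} \sum_{n=1}^{N} \delta(\mathbf{x} - \mathbf{x}_n) \, \delta(y - y_n) \tag{19.2}$$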


We can think of data augmentation as replacing the empirical distribution with the following
algorithmically smoothed distribution

$$p_{\mathcal{D}}(\mathbf{x}, y | A) = \frac{1}{N} \sum_{n=1}^{N} p(\mathbf{x} | \mathbf{x}_n, A) \, \delta(y - y_n) \tag{19.3}$$

where $A$ is the data augmentation algorithm, which generates a sample $\mathbf{x}$ from a training point $\mathbf{x}_n$,
such that the label (“semantics”) is not changed. (A very simple example would be a Gaussian kernel,
$p(\mathbf{x} | \mathbf{x}_n, A) = \mathcal{N}(\mathbf{x} | \mathbf{x}_n, \sigma^2 \mathbf{I})$.) This has been called vicinal risk minimization [Cha+01], since we
are minimizing the risk in the vicinity of each training point $\mathbf{x}_n$. For more details on this perspective,
see [Zha+17b; CDL19; Dao+19].
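
To make the vicinal risk concrete, the sketch below estimates it by Monte Carlo for the Gaussian kernel example, and compares it with the ordinary empirical risk. The classifier, the 0-1 loss, the data and the value of `sigma` are all hypothetical choices made for illustration.

```python
import numpy as np

def vicinal_risk(loss, predict, X, y, sigma, num_samples, rng):
    """Monte Carlo estimate of the risk under Gaussian-smoothed inputs."""
    total = 0.0
    for _ in range(num_samples):
        X_aug = X + sigma * rng.normal(size=X.shape)   # x ~ N(x_n, sigma^2 I)
        total += loss(predict(X_aug), y).mean()        # labels are unchanged
    return total / num_samples

# Hypothetical linear classifier and 0-1 loss.
predict = lambda X: (X @ np.array([1.0, -1.0]) > 0).astype(int)
loss = lambda yhat, y: (yhat != y).astype(float)

rng = np.random.default_rng(0)
X = np.array([[2.0, 0.0], [0.1, 0.0], [0.0, 2.0], [0.0, 0.1]])
y = np.array([1, 1, 0, 0])
empirical = loss(predict(X), y).mean()
vicinal = vicinal_risk(loss, predict, X, y, sigma=0.5, num_samples=200, rng=rng)
```

Here the empirical risk is zero, but the vicinal risk is not, because two of the training points lie close to the decision boundary. Minimizing the vicinal risk therefore favors boundaries that keep a margin around the data.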


19.2 Transfer learning 


This section is coauthored with Colin Raffel. 


Figure 19.2: Illustration of fine-tuning a model on a new dataset. The final output layer is trained from 
scratch, since it might correspond to a different label set. The other layers are initialized at their previous 
parameters, and then optionally updated using a small learning rate. From Figure 13.2.1 of [Zha+20]. Used 
with kind permission of Aston Zhang. 


Many data-poor tasks have some high-level structural similarity to other data-rich tasks. For 
example, consider the task of fine-grained visual classification of endangered bird species. Given 
that endangered birds are by definition rare, it is unlikely that a large quantity of diverse labeled 
images of these birds exist. However, birds bear many structural similarities across species - for 
example, most birds have wings, feathers, beaks, claws, etc. We therefore might expect that first 
training a model on a large dataset of non-endangered bird species and then continuing to train it on 
a small dataset of endangered species could produce better performance than training on the small 
dataset alone. 

This is called transfer learning, since we are transferring information from one dataset to another, 
via a shared set of parameters. More precisely, we first perform a pre-training phase, in which we 
train a model with parameters $\boldsymbol{\theta}$ on a large source dataset $\mathcal{D}_p$; this may be labeled or unlabeled.
We then perform a second fine-tuning phase on the small labeled target dataset $\mathcal{D}_q$ of interest.
We discuss these two phases in more detail below, but for more information, see e.g., [Tan+18; 
Zhu+21] for recent surveys. 


19.2.1 Fine-tuning 


Suppose, for now, that we already have a pretrained classifier, $p(y|\mathbf{x}, \boldsymbol{\theta}_p)$, such as a CNN, that works
well for inputs $\mathbf{x} \in \mathcal{X}_p$ (e.g., natural images) and outputs $y \in \mathcal{Y}_p$ (e.g., ImageNet labels), where the
data comes from a distribution $p(\mathbf{x}, y)$ similar to the one used in training. Now we want to create a
new model $q(y|\mathbf{x}, \boldsymbol{\theta}_q)$ that works well for inputs $\mathbf{x} \in \mathcal{X}_q$ (e.g., bird images) and outputs $y \in \mathcal{Y}_q$ (e.g.,
fine-grained bird labels), where the data comes from a distribution $q(\mathbf{x}, y)$ which may be different
from $p$.

We will assume that the set of possible inputs is the same, so $\mathcal{X}_q \approx \mathcal{X}_p$ (e.g., both are RGB images),
or that we can easily transform inputs from domain $p$ to domain $q$ (e.g., we can convert an RGB
image to grayscale by dropping the chrominance channels and just keeping luminance). (If this is not
the case, then we may need to use a method called domain adaptation, that modifies models to map
between modalities, as discussed in Section 19.2.5.)

Figure 19.3: (a) Adding adapter layers to a transformer. From Figure 2 of [Hou+19]. Used with kind
permission of Neil Houlsby. (b) Adding adapter layers to a resnet. From Figure 2 of [RBV18]. Used with
kind permission of Sylvestre-Alvise Rebuffi.

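However, the output domains are usually different, i.e., $\mathcal{Y}_q \neq \mathcal{Y}_p$. For example, $\mathcal{Y}_p$ might be
ImageNet labels and $\mathcal{Y}_q$ might be medical labels (e.g., types of diabetic retinopathy [Arc+19]). In
this case, we need to “translate” the output of the pre-trained model to the new domain. This
is easy to do with neural networks: we simply “chop off” the final layer of the original model,
and add a new “head” to model the new class labels, as illustrated in Figure 19.2. For example,
suppose $p(y|\mathbf{x}, \boldsymbol{\theta}_p) = \text{softmax}(y|\mathbf{W}_2 \mathbf{h}(\mathbf{x}; \boldsymbol{\theta}_1) + \mathbf{b}_2)$, where $\boldsymbol{\theta}_p = (\mathbf{W}_2, \mathbf{b}_2, \boldsymbol{\theta}_1)$. Then we can construct
$q(y|\mathbf{x}, \boldsymbol{\theta}_q) = \text{softmax}(y|\mathbf{W}_3 \mathbf{h}(\mathbf{x}; \boldsymbol{\theta}_1) + \mathbf{b}_3)$, where $\boldsymbol{\theta}_q = (\mathbf{W}_3, \mathbf{b}_3, \boldsymbol{\theta}_1)$ and $\mathbf{h}(\mathbf{x}; \boldsymbol{\theta}_1)$ is the shared nonlinear
feature extractor.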

After performing this “model surgery”, we can fine-tune the new model with parameters $\boldsymbol{\theta}_q =
(\boldsymbol{\theta}_1, \boldsymbol{\theta}_3)$, where $\boldsymbol{\theta}_1$ parameterizes the feature extractor, and $\boldsymbol{\theta}_3$ parameterizes the final linear layer that
maps features to the new set of labels. If we treat $\boldsymbol{\theta}_1$ as “frozen parameters”, then the resulting
model $q(y|\mathbf{x}, \boldsymbol{\theta}_q)$ is linear in its parameters, so we have a convex optimization problem for which many
simple and efficient fitting methods exist (see Part II). This is particularly helpful in the long-tail
setting, where some classes are very rare [Kan+20]. However, a linear “decoder” may be too limiting,
so we can also allow $\boldsymbol{\theta}_1$ to be fine-tuned as well, but using a lower learning rate, to prevent the values
moving too far from the values estimated on $\mathcal{D}_p$.
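
The frozen-feature case can be sketched in a few lines of NumPy: we keep the feature extractor fixed and fit only a new softmax head by gradient descent on the (convex) negative log likelihood. The random "pre-trained" weights `theta1`, the toy labels, and the step size are stand-ins chosen for illustration, not a real pre-trained model.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def nll(W, b, feats, y):
    """Mean negative log likelihood of the labels under the softmax head."""
    P = softmax(feats @ W + b)
    return -np.log(P[np.arange(len(y)), y]).mean()

def fit_head(feats, y, num_classes, lr=0.5, steps=500):
    """Gradient descent on the new head (W3, b3); the features stay fixed."""
    N, D = feats.shape
    W, b = np.zeros((D, num_classes)), np.zeros(num_classes)
    Y = np.eye(num_classes)[y]
    for _ in range(steps):
        P = softmax(feats @ W + b)
        W -= lr * (feats.T @ (P - Y)) / N
        b -= lr * (P - Y).mean(axis=0)
    return W, b

rng = np.random.default_rng(0)
theta1 = 0.3 * rng.normal(size=(5, 8))     # stand-in for pre-trained weights
h = lambda X: np.tanh(X @ theta1)          # frozen feature extractor h(x; theta1)
X = rng.normal(size=(200, 5))
y = (X[:, 0] > 0).astype(int)              # labels for the new task
W3, b3 = fit_head(h(X), y, num_classes=2)
loss_init = nll(np.zeros((8, 2)), np.zeros(2), h(X), y)
loss_fit = nll(W3, b3, h(X), y)
```

To fine-tune the whole model instead, we would also update `theta1`, typically with a smaller learning rate than the one used for the head.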


19.2.2 Adapters 


One disadvantage of fine-tuning all the model parameters of a pre-trained model is that it can be 
slow, since there are often many parameters, and we may need to use a small learning rate to prevent 
the low-level feature extractors from diverging too far from their prior values. In addition, every new 
task requires a new model to be trained, making task sharing hard. An alternative approach is to 
keep the pre-trained model untouched, but to add new parameters to modify its internal behavior to 
customize the feature extraction process for each task. This idea is called adapters, and has been
explored in several papers (e.g., [RBV17; RBV18; Hou+19]). 

Figure 19.3a illustrates adapters for transformer networks (Section 15.5), as proposed in [Hou+19]. 
The basic idea is to insert two shallow bottleneck MLPs inside each transformer layer, one after
the multi-head attention and one after the feed-forward layers. Note that these MLPs have skip
connections, so that they can be initialized to implement the identity mapping. If the transformer
layer has features of dimensionality $D$, and the adapter uses a bottleneck of size $M$, this introduces
$O(DM)$ new parameters per layer. These adapter MLPs, as well as the layer norm parameters
and final output head, are trained for each new task, but all the remaining parameters are frozen.
Empirically on several NLP benchmarks, this is found to give better performance than fine-tuning,
while only needing about 1-10% of the original parameters. 
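
The following sketch shows one bottleneck adapter with a skip connection, and checks the two properties mentioned above: it starts out as the identity mapping, and it adds $O(DM)$ parameters. The ReLU nonlinearity, the zero initialization of the up-projection, and the sizes `D` and `M` are assumptions made for this illustration.

```python
import numpy as np

def init_adapter(D, M, rng, scale=1e-3):
    """Bottleneck MLP parameters; the up-projection starts at zero."""
    return {
        "W_down": scale * rng.normal(size=(D, M)), "b_down": np.zeros(M),
        "W_up": np.zeros((M, D)), "b_up": np.zeros(D),
    }

def adapter(params, h):
    """h + up(relu(down(h))): a bottleneck MLP with a skip connection."""
    z = np.maximum(h @ params["W_down"] + params["b_down"], 0.0)
    return h + z @ params["W_up"] + params["b_up"]

D, M = 768, 64                      # illustrative sizes
rng = np.random.default_rng(0)
params = init_adapter(D, M, rng)
h = rng.normal(size=(10, D))        # activations from a frozen sublayer
out = adapter(params, h)
num_new = sum(p.size for p in params.values())   # 2*D*M + D + M
```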

Figure 19.3b illustrates adapters for residual networks (Section 14.3.4), as proposed in [RBV17; 
RBV18]. The basic idea is to add a 1x1 convolution layer $\boldsymbol{\alpha}$, which is analogous to the MLP adapter
in the transformer case, to the internal layers of the CNN. This can be added in series or in parallel,
as shown in the diagram. If we denote the adapter layer by $\rho(\mathbf{x})$, we can define the series adapter to
be

$$\rho(\mathbf{x}) = \mathbf{x} + \text{diag}_1(\boldsymbol{\alpha}) \circledast \mathbf{x} = \text{diag}_1(\mathbf{I} + \boldsymbol{\alpha}) \circledast \mathbf{x} \tag{19.4}$$


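where $\text{diag}_1(\boldsymbol{\alpha}) \in \mathbb{R}^{1 \times 1 \times C \times D}$ reshapes a matrix $\boldsymbol{\alpha} \in \mathbb{R}^{C \times D}$ into a matrix that can be applied to each
spatial location in parallel. (We have omitted batch normalization for simplicity.) If we insert this
after a regular convolution layer $\mathbf{f} \circledast \mathbf{x}$ we get

$$\mathbf{y} = \rho(\mathbf{f} \circledast \mathbf{x}) = (\text{diag}_1(\mathbf{I} + \boldsymbol{\alpha}) \circledast \mathbf{f}) \circledast \mathbf{x} \tag{19.5}$$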


This can be interpreted as a low-rank multiplicative perturbation to the original filter $\mathbf{f}$. The parallel
adapter can be defined by

$$\mathbf{y} = \mathbf{f} \circledast \mathbf{x} + \text{diag}_1(\boldsymbol{\alpha}) \circledast \mathbf{x} = (\mathbf{f} + \text{diag}_1(\boldsymbol{\alpha})) \circledast \mathbf{x} \tag{19.6}$$

This can be interpreted as a low-rank additive perturbation to the original filter $\mathbf{f}$. In both cases,
setting $\boldsymbol{\alpha} = \mathbf{0}$ ensures the adapter layers can be initialized to the identity transformation. In addition,
both methods require $O(C^2)$ parameters per layer.
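
We can check both identities numerically. The sketch below uses a naive "same" convolution (implemented as cross-correlation, as is usual in deep learning) and assumes the number of input and output channels are equal, so that the adapter matrix is $C \times C$. In the parallel case, adding the 1x1 adapter to the filter means adding it to the centre tap. All names and sizes are illustrative.

```python
import numpy as np

def conv2d_same(x, f):
    """'Same' 2d cross-correlation. x: (H, W, C), f: (L, L, C, D), L odd."""
    L = f.shape[0]
    p = L // 2
    xp = np.pad(x, ((p, p), (p, p), (0, 0)))
    H, W = x.shape[:2]
    y = np.zeros((H, W, f.shape[3]))
    for a in range(L):
        for b in range(L):
            y += xp[a:a + H, b:b + W, :] @ f[a, b]
    return y

rng = np.random.default_rng(0)
C, L = 4, 3
x = rng.normal(size=(6, 6, C))
f = rng.normal(size=(L, L, C, C))
alpha = 0.1 * rng.normal(size=(C, C))
eye = np.eye(C)

# Series adapter: apply the adapter to f * x, versus one fused filter.
z = conv2d_same(x, f)
series = z + z @ alpha
series_fused = conv2d_same(x, f @ (eye + alpha))

# Parallel adapter: f * x plus a 1x1 convolution of x, versus one fused filter.
parallel = conv2d_same(x, f) + x @ alpha
f_par = f.copy()
f_par[L // 2, L // 2] += alpha      # add alpha at the centre tap
parallel_fused = conv2d_same(x, f_par)
```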


19.2.3 Supervised pre-training 


The pre-training task may be supervised or unsupervised; the main requirements are that it can 
teach the model basic structure about the problem domain and that it is sufficiently similar to 
the downstream fine-tuning task. The notion of task similarity is not rigorously defined, but in 
practice the domain of the pre-training task is often more broad than that of the fine-tuning task 
(e.g., pre-train on all bird species and fine-tune on endangered ones). 

The most straightforward form of transfer learning is the case where a large labeled dataset is 
suitable for pre-training. For example, it is very common to use the ImageNet dataset (Section 1.5.1.2) 
to pretrain CNNs, which can then be used for a variety of downstream tasks and datasets (see
e.g., [Kol+19]). ImageNet has 1.28 million natural images, each associated with a label from one of
1,000 classes. The classes constitute a wide variety of different concepts, including animals, foods, 
buildings, musical instruments, clothing, and so on. The images themselves are diverse in the sense 
that they contain objects from many angles and in many sizes with a wide variety of backgrounds.
This diversity and scale may partially explain why it has become a de-facto pre-training task for 
transfer learning in computer vision. (See finetune_cnn_jax.ipynb for some example code.) 

However, ImageNet pre-training has been shown to be less helpful when the domain of the fine-tuning
task is quite different from natural images (e.g., medical images [Rag+19]). And in some cases
where it is helpful (e.g., training object detection systems), it seems to be more of a speedup trick 
(by warm-starting optimization at a good point) rather than something that is essential, in the sense 
that one can achieve comparable performance on the downstream task when training from scratch, if 
done for long enough [HGD19]. 

Supervised pre-training is somewhat less common in non-vision applications. One notable exception 
is to pre-train on natural language inference data (i.e. whether a sentence implies or contradicts 
another) to learn vector representations of sentences [Con+17], though this approach has largely been 
supplanted by unsupervised methods (Section 19.2.4). Another non-vision application of transfer 
learning is to pre-train a speech recognition model on a large English-labeled corpus before fine-tuning on
low-resource languages [Ard+20]. 


19.2.4 Unsupervised pre-training (self-supervised learning) 


It is increasingly common to use unsupervised pre-training, because unlabeled data is often easy 
to acquire, e.g., unlabeled images or text documents from the web. 

For a short period of time it was common to pre-train deep neural networks using an unsupervised 
objective (e.g., reconstruction error, as discussed in Section 20.3) over the labeled dataset (i.e. ignoring 
the labels) before proceeding with standard supervised training [HOT06; Vin+10b; Erh+10]. While 
this technique is also called unsupervised pre-training, it differs from the form of pre-training for 
transfer learning we discuss in this section, which uses a (large) unlabeled dataset for pre-training 
before fine-tuning on a different (smaller) labeled dataset. 

Pre-training tasks that use unlabeled data are often called self-supervised rather than unsuper- 
vised. This term is used because the labels are created by the algorithm, rather than being provided 
externally by a human, as in standard supervised learning. Both supervised and self-supervised 
learning are discriminative tasks, since they require predicting outputs given inputs. By contrast, 
other unsupervised approaches, such as some of those discussed in Chapter 20, are generative, since 
they predict outputs unconditionally. 

There are many different self-supervised learning heuristics that have been tried (see e.g., [GR18; 
JT19; Ren19] for a review, and https://github.com/jason718/awesome-self-supervised-learning
for an extensive list of papers). We can identify at least three main broad groups, which we discuss 
below. 


19.2.4.1 Imputation tasks 


One approach to self-supervised learning is to solve imputation tasks. In this approach, we partition 
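the input vector $\mathbf{x}$ into two parts, $\mathbf{x} = (\mathbf{x}_h, \mathbf{x}_v)$, and then try to predict the hidden part $\mathbf{x}_h$ given
the remaining visible part, $\mathbf{x}_v$, using a model of the form $\hat{\mathbf{x}}_h = f(\mathbf{x}_v, \mathbf{x}_h = \mathbf{0})$. We can think of this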
as a “fill-in-the-blank” task; in the NLP community, this is called a cloze task. See Figure 19.4 
for some visual examples, and Section 15.7.2 for some NLP examples. 
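
As a small sketch of how training pairs for an imputation task can be built, the code below hides a random subset of the entries of a vector, zeroes them in the input, and keeps the hidden values as the prediction target. The function names and the squared error loss are illustrative assumptions.

```python
import numpy as np

def make_imputation_example(x, num_hidden, rng):
    """Hide num_hidden random entries of a 1d vector x."""
    idx = rng.choice(x.size, size=num_hidden, replace=False)
    hidden = np.zeros(x.shape, dtype=bool)
    hidden[idx] = True
    x_visible = np.where(hidden, 0.0, x)   # model input, with x_h set to 0
    target = x[hidden]                     # the values x_h to be predicted
    return x_visible, hidden, target

def imputation_loss(x_pred, hidden, target):
    """Squared error measured on the hidden entries only."""
    return np.mean((x_pred[hidden] - target) ** 2)

rng = np.random.default_rng(0)
x = rng.normal(size=16)
x_visible, hidden, target = make_imputation_example(x, num_hidden=4, rng=rng)
baseline = imputation_loss(np.zeros_like(x), hidden, target)   # always predict 0
```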

Figure 19.4: (a) Context encoder for self-supervised learning. From [Pat+16]. Used with kind permission
of Deepak Pathak. (b) Some other proxy tasks for self-supervised learning. From [LeC18]. Used with kind 
permission of Yann LeCun. 


19.2.4.2 Proxy tasks 


Another approach to SSL is to solve proxy tasks, also called pretext tasks. In this setup, we 
create pairs of inputs, $(\mathbf{x}_1, \mathbf{x}_2)$, and then train a Siamese network classifier (Figure 16.5a) of the
form $p(y|\mathbf{x}_1, \mathbf{x}_2) = p(y|r[f(\mathbf{x}_1), f(\mathbf{x}_2)])$, where $f(\mathbf{x})$ is some function that performs “representation
learning” [BCV13], and $y$ is some label that captures the relationship between $\mathbf{x}_1$ and $\mathbf{x}_2$, which
is predicted by $r(\mathbf{f}_1, \mathbf{f}_2)$. For example, suppose $\mathbf{x}_1$ is an image patch, and $\mathbf{x}_2 = t(\mathbf{x}_1)$ is some
transformation of $\mathbf{x}_1$ that we control, such as a random rotation; then we define $y$ to be the rotation
angle that we used [GSK18]. 
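
The rotation example can be sketched as follows. We restrict the rotations to multiples of 90 degrees, so the label is one of four classes; this restriction and the function name are assumptions made to keep the example short.

```python
import numpy as np

def make_rotation_pair(patch, rng):
    """Return (x1, x2, y): a patch, a rotated copy, and the rotation label."""
    y = int(rng.integers(0, 4))        # number of 90-degree turns
    return patch, np.rot90(patch, k=y), y

rng = np.random.default_rng(0)
patch = rng.random((8, 8))             # stand-in for an image patch
x1, x2, y = make_rotation_pair(patch, rng)
```

A classifier trained to recover `y` from `(x1, x2)` gets its labels for free, since we chose the transformation ourselves.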


19.2.4.3 Contrastive tasks 


The currently most popular approach to self-supervised learning is to use various kinds of contrastive 
tasks. The basic idea is to create pairs of examples that are semantically similar to each other, 
using data augmentation methods (Section 19.1), and then to ensure that the distance between their 
representations is closer (in embedding space) than the distance between two unrelated examples. 
This is exactly the same idea that is used in deep metric learning (Section 16.2.2) — the only 
difference is that the algorithm creates its own similar pairs, rather than relying on an externally 
provided measure of similarity, such as labels. We give some examples of this in Section 19.2.4.4 and 
Section 19.2.4.5. 


19.2.4.4 SimCLR 


In this section, we discuss SimCLR, which stands for “Simple contrastive learning of visual repre- 
sentations” [Che+20b; Che+20c]. This has shown state of the art performance on transfer learning 
and semi-supervised learning. The basic idea is as follows. Each input $\mathbf{x} \in \mathbb{R}^D$ is converted to
two augmented “views” $\mathbf{x}_1 = t_1(\mathbf{x})$, $\mathbf{x}_2 = t_2(\mathbf{x})$, which are “semantically equivalent” versions of
the input generated by some transformations $t_1, t_2$. For example, if $\mathbf{x}$ is an image, these could be
small perturbations to the image, such as random crops, as discussed in Section 19.1. In addition,
we sample “negative” examples $\mathbf{x}_1^-, \ldots, \mathbf{x}_n^- \in N(\mathbf{x})$ from the dataset which represent “semantically
different” images (in practice, these are the other examples in the minibatch). Next we define some
feature mapping $F: \mathbb{R}^D \to \mathbb{R}^E$, where $D$ is the size of the input, and $E$ is the size of the embedding.

Figure 19.5: (a) Illustration of SimCLR training. T is a set of stochastic semantics-preserving transformations
(data augmentations). (b-c) Illustration of the benefit of random crops. Solid rectangles represent the original 
image, dashed rectangles are random crops. In (b), the model is forced to predict the local view A from the 
global view B (and vice versa). In (c), the model is forced to predict the appearance of adjacent views (C,D). 
From Figures 2-3 of [Che+20b]. Used with kind permission of Ting Chen. 

Figure 19.6: Visualization of SimCLR training. Each input image in the minibatch is randomly modified in
two different ways (using cropping (followed by resize), flipping, and color distortion), and then fed into a
Siamese network. The embeddings (final layer) for each pair derived from the same image are forced to be
close, whereas the embeddings for all other pairs are forced to be far. From
https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html. Used with kind permission of Ting Chen.

We then try to maximize the similarity of the similar views, while minimizing the similarity of the
different views, for each input $\mathbf{x}$:

$$J = F(t_1(\mathbf{x}))^{\top} F(t_2(\mathbf{x})) - \log \sum_{\mathbf{x}_i^- \in N(\mathbf{x})} \exp\left[ F(\mathbf{x}_i^-)^{\top} F(t_1(\mathbf{x})) \right] \tag{19.7}$$


In practice, we use cosine similarity, so we $\ell_2$-normalize the representations produced by $F$ before
taking inner products, but this is omitted in the above equation. See Figure 19.5a for an illustration.
(In this figure, we assume $F(\mathbf{x}) = g(r(\mathbf{x}))$, where the intermediate representation $\mathbf{h} = r(\mathbf{x})$ is the one
that will be later used for fine-tuning, and $g$ is an additional transformation applied during training.)
Interestingly, we can interpret this as a form of conditional energy based model of the form

$$p(\mathbf{x}_2 | \mathbf{x}_1) = \frac{\exp[-\mathcal{E}(\mathbf{x}_2 | \mathbf{x}_1)]}{Z(\mathbf{x}_1)} \tag{19.8}$$

where $\mathcal{E}(\mathbf{x}_2 | \mathbf{x}_1) = -F(\mathbf{x}_2)^{\top} F(\mathbf{x}_1)$ is the energy, and

$$Z(\mathbf{x}) = \int \exp[-\mathcal{E}(\mathbf{x}^- | \mathbf{x})] \, d\mathbf{x}^- = \int \exp[F(\mathbf{x}^-)^{\top} F(\mathbf{x})] \, d\mathbf{x}^- \tag{19.9}$$


is the normalization constant, known as the partition function. The conditional log likelihood 
under this model has the form 


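$$\log p(\mathbf{x}_2 | \mathbf{x}_1) = F(\mathbf{x}_2)^{\top} F(\mathbf{x}_1) - \log \int \exp[F(\mathbf{x}^-)^{\top} F(\mathbf{x}_1)] \, d\mathbf{x}^- \tag{19.10}$$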


The only difference from Equation (19.7) is that we replace the integral with a Monte Carlo upper 
bound derived from the negative samples. Thus we can think of contrastive learning as approximate 
maximum likelihood estimation of a conditional energy based generative model [Gra+20]. More 
details on such models can be found in the sequel to this book, [Mur23]. 
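
The objective in Equation (19.7) can be sketched for a single input as follows, using cosine similarity as described above. The embeddings here are random vectors standing in for the outputs of $F$, and the function names are illustrative; this is not the SimCLR reference implementation.

```python
import numpy as np

def l2_normalize(z):
    return z / np.linalg.norm(z, axis=-1, keepdims=True)

def contrastive_objective(z1, z2, z_neg):
    """J from Equation (19.7) for one input, using cosine similarity.

    z1, z2: embeddings of the two views, shape (E,).
    z_neg:  embeddings of the negative examples, shape (n, E).
    """
    z1, z2, z_neg = l2_normalize(z1), l2_normalize(z2), l2_normalize(z_neg)
    pos = z1 @ z2
    neg = z_neg @ z1
    m = neg.max()
    return pos - (m + np.log(np.exp(neg - m).sum()))   # stable log-sum-exp

rng = np.random.default_rng(0)
z = rng.normal(size=32)
z_neg = rng.normal(size=(8, 32))
J_close = contrastive_objective(z, z + 0.01 * rng.normal(size=32), z_neg)
J_far = contrastive_objective(z, rng.normal(size=32), z_neg)
```

The objective is larger when the second view is embedded close to the first, which is what training encourages.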

A critical ingredient to the success of SimCLR is the choice of data augmentation methods. By 
using random cropping, they can force the model to predict local views from global views, as well as 
to predict adjacent views of the same image (see Figure 19.5). After cropping, all images are resized 
back to the same size. In addition, they randomly flip the image some fraction of the time.²

SimCLR relies on large batch training, in order to ensure a sufficiently diverse set of negatives. 
When this is not possible, we can use a memory bank of past (negative) embeddings, which can 
be updated using exponential moving averaging (Section 4.4.2.2). This is known as momentum 
contrastive learning or MoCo [He+20].


2. It turns out that distinguishing positive crops (from the same image) from negative crops (from different images) is
often easy to do just based on color histograms. To prevent this kind of “cheating”, they also apply a random color
distortion, thus cutting off this “short circuit”. The combination of random cropping and color distortion is found to
work better than either method alone.


19.2.4.5 CLIP


In this section, we describe CLIP, which stands for “Contrastive Language-Image Pre-training”
[Rad+]. This is a contrastive approach to representation learning which uses a massive corpus of
400M (image, text) pairs extracted from the web. Let $\mathbf{x}_i$ be the $i$th image and $\mathbf{y}_i$ be its matching
text. Rather than trying to predict the exact words associated with the image, it is simpler to just
determine if $\mathbf{y}_i$ is more likely to be the correct text compared to $\mathbf{y}_j$, for some other text string $j$ in
the minibatch. Similarly, the model can try to determine if image $\mathbf{x}_i$ is more likely to be matched
than $\mathbf{x}_j$ to a given text $\mathbf{y}_i$.

Figure 19.7: Illustration of the CLIP model. From Figure 1 of [Rad+]. Used with kind permission of Alec
Radford.

More precisely, let $f_I(\mathbf{x}_i)$ be the embedding of the image, $f_T(\mathbf{y}_i)$ be the embedding of the text,
$\mathbf{I}_i = f_I(\mathbf{x}_i) / \|f_I(\mathbf{x}_i)\|_2$ be the unit-norm version of the image embedding, and $\mathbf{T}_i = f_T(\mathbf{y}_i) / \|f_T(\mathbf{y}_i)\|_2$
be the unit-norm version of the text embedding. Define the matrix of pairwise logits (similarity scores)
to be

$$L_{ij} = \mathbf{I}_i^{\top} \mathbf{T}_j \tag{19.11}$$


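We now train the parameters of the two embedding functions $f_I$ and $f_T$ to minimize the following
loss, averaged over minibatches of size $N$:

$$J = \frac{1}{2} \left[ \sum_{i=1}^{N} \text{CE}(\mathbf{L}_{i,:}, \mathbf{1}_i) + \sum_{j=1}^{N} \text{CE}(\mathbf{L}_{:,j}, \mathbf{1}_j) \right] \tag{19.12}$$

where CE is the cross entropy loss

$$\text{CE}(\mathbf{p}, \mathbf{q}) = -\sum_{k=1}^{K} p_k \log q_k \tag{19.13}$$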

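and $\mathbf{1}_i$ is a one-hot encoding of label $i$. (In Equation (19.12), each row or column of logits is first
converted to a probability distribution using a softmax, and the one-hot vector acts as the target.)
See Figure 19.7a for an illustration. (In practice, the
normalized embeddings are scaled by a temperature parameter which is also learned; this controls
the sharpness of the softmax.)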
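
The symmetric loss of Equation (19.12) can be sketched in NumPy as follows. We take the embeddings as given arrays, apply a softmax to each row and each column of the logit matrix, and use a fixed `temperature` instead of a learned one; these simplifications and the function names are assumptions made for illustration.

```python
import numpy as np

def log_softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def clip_loss(img_emb, txt_emb, temperature=1.0):
    """Symmetric cross entropy over the N x N logit matrix of Equation (19.11)."""
    I = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    T = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    logits = I @ T.T / temperature
    diag = np.arange(len(logits))
    loss_rows = -log_softmax(logits, axis=1)[diag, diag].sum()   # image to text
    loss_cols = -log_softmax(logits, axis=0)[diag, diag].sum()   # text to image
    return 0.5 * (loss_rows + loss_cols)

rng = np.random.default_rng(0)
emb = rng.normal(size=(4, 16))
matched = clip_loss(emb, emb, temperature=0.1)          # pair i matches text i
shuffled = clip_loss(emb, emb[::-1], temperature=0.1)   # every pair is mismatched
```

The loss is small when each image is most similar to its own text, and large when the pairs are shuffled.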

In their paper, they considered using a ResNet (Section 14.3.4) and a vision transformer
(Section 15.5.6) for the function $f_I$, and a text transformer (Section 15.5) for $f_T$. They used a very large
minibatch of $N \sim 32$k, and trained for many days on 100s of GPUs.

After the model is trained, it can be used for zero-shot classification of an image $\mathbf{x}$ as follows.
First each of the $K$ possible class labels for a given dataset is converted into a text string $\mathbf{y}_k$ that
might occur on the web. For example, “dog” becomes “a photo of a dog”. Second, we compute the
normalized embeddings $\mathbf{I} \propto f_I(\mathbf{x})$ and $\mathbf{T}_k \propto f_T(\mathbf{y}_k)$. Third, we compute the softmax probabilities

$$p(y = k | \mathbf{x}) = \text{softmax}([\mathbf{I}^{\top} \mathbf{T}_1, \ldots, \mathbf{I}^{\top} \mathbf{T}_K])_k \tag{19.14}$$


See Figure 19.7b for an illustration. (A similar approach was adopted in the visual n-grams paper 
[Li+17].) 
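
The three steps above amount to a softmax over cosine similarities. In the sketch below the embeddings are small hand-made vectors; in practice they would come from the trained image and text encoders applied to the image and to the prompt for each class. The names and values are hypothetical.

```python
import numpy as np

def zero_shot_probs(img_emb, class_text_embs):
    """p(y = k | x) from Equation (19.14): softmax over cosine similarities."""
    I = img_emb / np.linalg.norm(img_emb)
    T = class_text_embs / np.linalg.norm(class_text_embs, axis=1, keepdims=True)
    logits = T @ I
    e = np.exp(logits - logits.max())
    return e / e.sum()

# Hypothetical embeddings of the prompts "a photo of a dog", "... cat", "... car".
class_names = ["dog", "cat", "car"]
class_text_embs = np.array([[1.0, 0.0, 0.0],
                            [0.0, 1.0, 0.0],
                            [0.0, 0.0, 1.0]])
img_emb = np.array([0.2, 2.0, 0.1])      # hypothetical image embedding
probs = zero_shot_probs(img_emb, class_text_embs)
prediction = class_names[int(probs.argmax())]
```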

Remarkably, this approach can perform as well as standard supervised learning on tasks such as 
ImageNet classification, without ever being explicitly trained on specific labeled datasets. Of course, 
the images in ImageNet come from the web, and were found using text-based web-search, so the 
model has seen similar data before. Nevertheless, its generalization to new tasks, and robustness to 
distribution shift, are quite impressive (see the paper for examples). 

One drawback of the approach, however, is that it is sensitive to how class labels are converted to 
textual form. For example, to make the model work on food classification, it is necessary to use text 
strings of the form “a photo of guacamole, a type of food”, “a photo of ceviche, a type of food”, etc. 
Disambiguating phrases such as “a type of food” are currently added by hand, on a per-dataset basis. 
This is called prompt engineering, and is needed since the raw class names can be ambiguous 
across (and sometimes within) a dataset. 


19.2.5 Domain adaptation 


Consider a problem in which we have inputs from different domains, such as a source domain 
$\mathcal{X}_s$ and target domain $\mathcal{X}_t$, but a common set of output labels, $\mathcal{Y}$. (This is the “dual” of transfer
learning, since the input domains are different, but the output domains are the same.) For example, the
domains might be images from a computer graphics system and real images, or product reviews and 
movie reviews. We assume we do not have labeled examples from the target domain. Our goal is to 
fit the model on the source domain, and then modify its parameters so it works on the target domain. 
This is called (unsupervised) domain adaptation (see e.g., [KL21] for a review). 

A common approach to this problem is to train the source classifier in such a way that it cannot 
distinguish whether the input is coming from the source or target distribution; in this case, it will 
only be able to use features that are common to both domains. This is called domain adversarial 
learning [Gan+16]. More formally, let $d_n \in \{s, t\}$ be a label that specifies if the data example $n$
comes from domain $s$ or $t$. We want to optimize

$$\min_{\boldsymbol{\phi}} \max_{\boldsymbol{\theta}} \frac{1}{N_s + N_t} \sum_{n \in \mathcal{D}_s, \mathcal{D}_t} \ell(d_n, f_{\boldsymbol{\theta}}(\mathbf{x}_n)) + \frac{1}{N_s} \sum_{m \in \mathcal{D}_s} \ell(y_m, g_{\boldsymbol{\phi}}(f_{\boldsymbol{\theta}}(\mathbf{x}_m))) \tag{19.15}$$

where $N_s = |\mathcal{D}_s|$, $N_t = |\mathcal{D}_t|$, $f$ maps $\mathcal{X}_s \cup \mathcal{X}_t \to \mathcal{H}$, and $g$ maps $\mathcal{H} \to \mathcal{Y}_t$. The objective in
Equation (19.15) minimizes the loss on the desired task of classifying $y$, but maximizes the loss on
the auxiliary task of classifying the source domain $d$. This can be implemented by the gradient sign
reversal trick, and is related to GANs (generative adversarial networks). See e.g., [Csu17; Wu+19]
for some other approaches to domain adaptation.
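
The gradient sign reversal trick can be illustrated with manual backpropagation on a tiny model. The reversal layer is the identity in the forward pass, and multiplies the gradient by $-\lambda$ in the backward pass, so an ordinary descent step on the feature extractor increases the domain classification loss. The linear feature extractor `W`, the logistic domain classifier `v`, and all values below are hypothetical.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def domain_loss(W, v, x, d):
    """Logistic loss of a domain classifier v applied to features h = W x."""
    p = sigmoid(v @ (W @ x))
    return -(d * np.log(p) + (1 - d) * np.log(1 - p))

def grads_with_reversal(W, v, x, d, lam=1.0):
    """Backprop through features, then gradient reversal, then domain classifier."""
    h = W @ x
    err = sigmoid(v @ h) - d          # derivative of the loss w.r.t. the logit
    grad_v = err * h                  # domain classifier: ordinary gradient
    grad_h = err * v                  # gradient arriving at the reversal layer
    grad_h = -lam * grad_h            # identity forward, sign flip backward
    grad_W = np.outer(grad_h, x)      # feature extractor sees the reversed gradient
    return grad_W, grad_v

rng = np.random.default_rng(0)
W, v, x, d = rng.normal(size=(3, 4)), rng.normal(size=3), rng.normal(size=4), 1.0
grad_W, grad_v = grads_with_reversal(W, v, x, d)
lr = 0.1
loss_before = domain_loss(W, v, x, d)
loss_after_W = domain_loss(W - lr * grad_W, v, x, d)   # descent step on W only
```

In a full implementation the feature extractor would also receive the (unreversed) gradient of the label classification loss.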

Figure 19.8: Illustration of the benefits of semi-supervised learning for a binary classification problem. Labeled
points from each class are shown as black and white circles respectively. (a) Decision boundary we might 
learn given only labeled data. (b) Decision boundary we might learn if we also had a lot of unlabeled data 
points, shown as smaller grey circles. 


19.3 Semi-supervised learning 


This section is co-authored with Colin Raffel. 


Many recent successful applications of machine learning are in the supervised learning setting, 
where a large dataset of labeled examples is available for training a model. However, in many
practical applications it is expensive to obtain this labeled data. Consider the case of automatic speech 
recognition: Modern datasets contain thousands of hours of audio recordings [Pan+15; Ard+20]. The 
process of annotating the words spoken in a recording is many times slower than realtime, potentially 
resulting in a long (and costly) annotation process. To make matters worse, in some applications 
data must be labeled by an expert (such as a doctor in medical applications) which can further 
increase costs. 

Semi-supervised learning can alleviate the need for labeled data by taking advantage of 
unlabeled data. The general goal of semi-supervised learning is to allow the model to learn the 
high-level structure of the data distribution from unlabeled data and only rely on the labeled data 
for learning the fine-grained details of a given task. Whereas in standard supervised learning we 
assume that we have access to samples from the joint distribution of data and labels x,y ~ p(x, y), 
semi-supervised learning assumes that we additionally have access to samples from the marginal 
distribution of x, namely x ~ p(x), as illustrated in Figure 19.8. Further, it is generally assumed that 
we have many more of these unlabeled samples since they are typically cheaper to obtain. Continuing 
the example of automatic speech recognition, it is often much cheaper to simply record people talking 
(which would produce unlabeled data) than it is to transcribe recorded speech. Semi-supervised 
learning is a good fit for the scenario where a large amount of unlabeled data has been collected and 
the practitioner would like to avoid having to label all of it. 


19.3.1 Self-training and pseudo-labeling 


An early and straightforward approach to semi-supervised learning is self-training [Scu65; Agr70; 
McL75]. The basic idea behind self-training is to use the model itself to infer predictions on unlabeled 
data, and then treat these predictions as labels for subsequent training. Self-training has endured 
as a semi-supervised learning method because of its simplicity and general applicability; i.e. it is 
applicable to any model that can generate predictions for the unlabeled data. Recently, it has become 
common to refer to this approach as “pseudo-labeling” [Lee13] because the inferred labels for 
unlabeled data are only “pseudo-correct” in comparison with the true, ground-truth targets used in 
supervised learning. 

Algorithmically, self-training typically follows one of the following two procedures. In the first 
approach, pseudo-labels are first predicted for the entire collection of unlabeled data and the model 
is re-trained (possibly from scratch) to convergence on the combination of the labeled and (pseudo- 
labeled) unlabeled data. Then, the unlabeled data is re-labeled by the model and the process repeats 
itself until a suitable solution is found. The second approach instead continually generates predictions 
on randomly-chosen batches of unlabeled data and immediately trains the model against these 
pseudo-labels. Both approaches are currently common in practice; the first “offline” variant has been 
shown to be particularly successful when leveraging giant collections of unlabeled data [Yal+19; 
Xie+20] whereas the “online” approach is often used as one component of more sophisticated semi- 
supervised learning methods [Soh+20]. Neither variant is fundamentally better than the other. Offline 
self-training can result in training the model on “stale” pseudo-labels, since they are only updated 
each time the model converges. However, online pseudo-labeling can incur larger computational costs 
since it involves constantly “re-labeling” unlabeled data. 

Self-training can suffer from an obvious problem: If the model generates incorrect predictions for 
unlabeled data and then is re-trained on these incorrect predictions, it can become progressively 
worse and worse at the intended classification task until it eventually learns a totally invalid solution. 
This issue has been dubbed confirmation bias [TV17] because the model is continually confirming 
its own (incorrect) bias about the decision rule. 

A common way to mitigate confirmation bias is to use a “selection metric” [RHS05] which 
heuristically tries to only retain pseudo-labels that are correct. For example, assuming that a 
model outputs probabilities for each possible class, a frequently-used selection metric is to only retain 
pseudo-labels whose largest class probability is above a threshold [Yar95; RHS05]. If the model’s 
class probability estimates are well-calibrated, then this selection metric will only retain labels that 
are highly likely to be correct (according to the model, at least). More sophisticated selection metrics 
can be designed according to the problem domain. 
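
As a minimal sketch of the thresholding selection metric described above, the following NumPy snippet keeps only those pseudo-labels whose largest class probability exceeds a threshold. The function name `select_pseudo_labels` and the threshold value 0.95 are illustrative assumptions, not taken from the cited papers.

```python
import numpy as np

def select_pseudo_labels(probs, threshold=0.95):
    """Keep pseudo-labels whose top class probability is above `threshold`.

    probs: (N, C) array of predicted class probabilities on unlabeled data.
    Returns the indices of the retained examples and their hard pseudo-labels.
    """
    probs = np.asarray(probs, dtype=float)
    confidence = probs.max(axis=1)   # largest class probability per example
    labels = probs.argmax(axis=1)    # hard pseudo-label per example
    keep = np.flatnonzero(confidence > threshold)
    return keep, labels[keep]

# Toy predictions for three unlabeled examples and two classes.
probs = np.array([[0.98, 0.02], [0.60, 0.40], [0.03, 0.97]])
idx, pseudo = select_pseudo_labels(probs, threshold=0.95)
```

Here the second example is discarded because the model is not confident about it. Whether the retained labels are actually correct depends on how well calibrated the model is.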


19.3.2 Entropy minimization 


Self-training has the implicit effect of encouraging the model to output low-entropy (i.e. high- 
confidence) predictions. This effect is most apparent in the online setting with a cross-entropy loss, 
where the model minimizes the following loss function $\mathcal{L}$ on unlabeled data: 


$$\mathcal{L} = -\max_c \log p_\theta(y = c|x) \tag{19.16}$$


where $p_\theta(y|x)$ is the model’s class probability distribution given input $x$. This function is minimized 
when the model assigns all of its class probability to a single class $c^*$, i.e. $p(y = c^*|x) = 1$ and 
$p(y \neq c^*|x) = 0$. 

Figure 19.9: Comparison of the entropy minimization, self-training, and “sharpened” entropy minimization 
loss functions for a binary classification problem. 


A closely-related semi-supervised learning method is entropy minimization [GB05], which 
minimizes the following loss function: 


$$\mathcal{L} = -\sum_{c=1}^{C} p_\theta(y = c|x) \log p_\theta(y = c|x) \tag{19.17}$$


Note that this function is also minimized when the model assigns all of its class probability to a 
single class. We can make the entropy-minimization loss in Equation (19.17) equivalent to the online 
self-training loss in Equation (19.16) by replacing the first $p_\theta(y = c|x)$ term with a “one-hot” vector 
that assigns a probability of 1 to the class that was assigned the highest probability. In other words, 
online self-training minimizes the cross-entropy between the model’s output and the “hard” target 
$\arg\max p_\theta(y|x)$, whereas entropy minimization uses the “soft” target $p_\theta(y|x)$. One way to trade 
off between these two extremes is to adjust the “temperature” of the target distribution by raising 
each probability to the power of $1/T$ and renormalizing; this is the basis of the mixmatch method 
of [Ber+19b; Ber+19a; Xie+19]. At $T = 1$, this is equivalent to entropy minimization; as $T \to 0$, it 
becomes hard online self-training. A comparison of these loss functions is shown in Figure 19.9. 
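
To make the role of the temperature concrete, here is a small NumPy sketch of the sharpening operation and of the three unlabeled losses for a single prediction. The helper names `sharpen` and `cross_entropy`, the small constant `eps`, and the example probabilities are assumptions made for illustration; sharpening is done in log space only for numerical stability at small $T$.

```python
import numpy as np

def sharpen(p, T, eps=1e-12):
    """Raise each probability to the power 1/T and renormalize."""
    log_q = np.log(np.asarray(p, dtype=float) + eps) / T
    log_q -= log_q.max(axis=-1, keepdims=True)   # avoid underflow for small T
    q = np.exp(log_q)
    return q / q.sum(axis=-1, keepdims=True)

def cross_entropy(target, p, eps=1e-12):
    """Cross-entropy between a (fixed) target distribution and prediction p."""
    return -np.sum(target * np.log(p + eps), axis=-1)

p = np.array([0.7, 0.2, 0.1])                          # model prediction
entropy_loss = cross_entropy(sharpen(p, 1.0), p)       # soft target, T = 1
sharp_loss = cross_entropy(sharpen(p, 0.5), p)         # sharpened target
hard_loss = cross_entropy(np.eye(3)[p.argmax()], p)    # one-hot target
```

For this prediction the three losses are ordered `entropy_loss > sharp_loss > hard_loss`, and `sharpen(p, T)` approaches the one-hot target as `T` shrinks.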


19.3.2.1 The cluster assumption 


Why is entropy minimization a good idea? A basic assumption of many semi-supervised learning 
methods is that the decision boundary between classes should fall in a low-density region of the 
data manifold. This effectively assumes that the data corresponding to different classes are clustered 
together. A good decision boundary, therefore, should not pass through clusters; it should simply 
separate them. Semi-supervised learning methods that make the “cluster assumption” can be 
thought of as using unlabeled data to estimate the shape of the data manifold and moving the 
decision boundary away from it. 

Entropy minimization is one such method. To see why, first assume that the decision boundary 
between two classes is “smooth”, i.e. the model does not abruptly change its class prediction anywhere 
in its domain. This is true in practice for simple and/or regularized models. In this case, if the 
decision boundary passes through a high-density region of data, it will by necessity produce high- 
entropy predictions for some samples from the data distribution. Entropy minimization will therefore 
encourage the model to place its decision boundary in low-density regions of the input space to 
avoid transitioning from one class to another in a region of space where data may be sampled. A 
visualization of this behavior is shown in Figure 19.10. 


Figure 19.10: Visualization demonstrating how entropy minimization enforces the cluster assumption. The 
classifier assigns a higher probability to class 1 (black dots) or 2 (white dots) in red or blue regions respectively. 
The predicted class probabilities for one particular unlabeled datapoint are shown in the bar plot. In (a), the 
decision boundary passes through high-density regions of data, so the classifier is forced to output high-entropy 
predictions. In (b), the classifier avoids high-density regions and is able to assign low-entropy predictions to 
most of the unlabeled data. 


19.3.2.2 Input-output mutual information 


An alternative justification for the entropy minimization objective was proposed by Bridle, Heading, 
and MacKay [BHM92], where it was shown that it naturally arises from maximizing the mutual 
information (Section 6.3) between the data and the label (i.e. the input and output of a model). 
Denoting $x$ as the input and $y$ as the target, the input-output mutual information can be written as 


$$
\begin{aligned}
\mathcal{I}(y; x) &= \iint p(y, x) \log \frac{p(y, x)}{p(y) p(x)} \, dy \, dx && (19.18) \\
&= \iint p(y|x) p(x) \log \frac{p(y, x)}{p(y) p(x)} \, dy \, dx && (19.19) \\
&= \int p(x) \, dx \int p(y|x) \log \frac{p(y|x)}{p(y)} \, dy && (19.20) \\
&= \int p(x) \, dx \int p(y|x) \log \frac{p(y|x)}{\int p(x) p(y|x) \, dx} \, dy && (19.21)
\end{aligned}
$$


Note that the first integral is equivalent to taking an expectation over x, and the second integral is 
equivalent to summing over all possible values of the class y. Using these relations, we obtain 


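$$
\begin{aligned}
\mathcal{I}(y; x) &= \mathbb{E}_x \left[ \sum_{i=1}^{L} p(y_i|x) \log \frac{p(y_i|x)}{\mathbb{E}_x[p(y_i|x)]} \right] && (19.22) \\
&= \mathbb{E}_x \left[ \sum_{i=1}^{L} p(y_i|x) \log p(y_i|x) \right] - \mathbb{E}_x \left[ \sum_{i=1}^{L} p(y_i|x) \log \mathbb{E}_x[p(y_i|x)] \right] && (19.23) \\
&= \mathbb{E}_x \left[ \sum_{i=1}^{L} p(y_i|x) \log p(y_i|x) \right] - \sum_{i=1}^{L} \mathbb{E}_x \left[ p(y_i|x) \log \mathbb{E}_x[p(y_i|x)] \right] && (19.24)
\end{aligned}
$$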


Since we had initially sought to maximize the mutual information, and we typically minimize loss 
functions, we can convert this to a suitable loss function by negating it: 


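$$\mathcal{L}(y; x) = -\mathbb{E}_x \left[ \sum_{i=1}^{L} p(y_i|x) \log p(y_i|x) \right] + \sum_{i=1}^{L} \mathbb{E}_x \left[ p(y_i|x) \log \mathbb{E}_x[p(y_i|x)] \right] \tag{19.25}$$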


The first term is exactly the entropy minimization objective in expectation. The second term specifies 
that we should maximize the entropy of the expected class prediction, i.e. the average class prediction 
over our training set. This encourages the model to predict each possible class with equal probability, 
which is only appropriate when we know a priori that all classes are equally likely. 
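
The two terms of Equation (19.25) can be estimated on a batch by replacing the expectation over x with an average over examples. The sketch below does this with NumPy; the function name `mutual_information_loss`, the constant `eps`, and the toy batches are assumptions introduced for illustration.

```python
import numpy as np

def mutual_information_loss(probs, eps=1e-12):
    """Batch estimate of the negated input-output mutual information.

    probs: (N, L) array; row n holds the predicted class distribution for x_n.
    """
    probs = np.asarray(probs, dtype=float)
    # First term: average entropy of the individual predictions.
    mean_entropy = -np.mean(np.sum(probs * np.log(probs + eps), axis=1))
    # Second term: negative entropy of the batch-averaged prediction.
    marginal = probs.mean(axis=0)
    neg_marginal_entropy = np.sum(marginal * np.log(marginal + eps))
    return mean_entropy + neg_marginal_entropy

balanced = np.array([[1.0, 0.0], [0.0, 1.0]])    # confident, classes balanced
collapsed = np.array([[1.0, 0.0], [1.0, 0.0]])   # confident, one class only
loss_balanced = mutual_information_loss(balanced)
loss_collapsed = mutual_information_loss(collapsed)
```

Both batches have zero prediction entropy, but only the balanced one also maximizes the entropy of the average prediction, so it attains the lower loss (about $-\log 2$ versus about 0).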


19.3.3 Co-training 


Co-training [B98] is also similar to self-training, but makes an additional assumption that there 
are two complementary “views” (i.e. independent sets of features) of the data, both of which can 
be used separately to train a reasonable model. After training two models separately on each view, 
unlabeled data is classified by each model to obtain candidate pseudo-labels. If a particular pseudo- 
label receives a low-entropy prediction (indicating high confidence) from one model and a high-entropy 
prediction (indicating low confidence) from the other, then that pseudo-labeled datapoint is added to 
the training set for the low-confidence model. Then, the process is repeated with the new, larger 
training datasets. The procedure of only retaining pseudo-labels when one of the models is confident 
ideally builds up the training sets with correctly-labeled data. 

Co-training makes the strong assumption that there are two informative-but-independent views 
of the data, which may not be true for many problems. The Tri-Training algorithm [ZL05] 
circumvents this issue by instead using three models that are first trained on independently-sampled 
(with replacement) subsets of the labeled data. Ideally, initially training on different collections of 
labeled data results in models that do not always agree on their predictions. Then, pseudo-labels are 
generated for the unlabeled data independently by each of the three models. For a given unlabeled 
datapoint, if two of the models agree on the pseudo-label, it is added to the training set for the 
third model. This can be seen as a selection metric, because it only retains pseudo-labels where 
two (differently initialized) models agree on the correct label. The models are then re-trained on 
the combination of the labeled data and the new pseudo-labels, and the whole process is repeated 
iteratively. 
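
The agreement step of Tri-Training can be sketched as follows. This covers only the selection rule described above (two models agree, the third receives the pseudo-label) and omits any further checks the full algorithm of [ZL05] may apply; the function name `tri_training_candidates` and the toy predictions are assumptions.

```python
import numpy as np

def tri_training_candidates(preds):
    """Select pseudo-labels for each model from the other two models' agreement.

    preds: (3, N) integer array of hard predictions on the unlabeled data.
    Returns {m: (indices, labels)} where the two models other than m agree.
    """
    preds = np.asarray(preds)
    out = {}
    for m in range(3):
        a, b = [k for k in range(3) if k != m]
        agree = np.flatnonzero(preds[a] == preds[b])
        out[m] = (agree, preds[a][agree])
    return out

preds = np.array([[0, 1, 1],
                  [0, 1, 0],
                  [1, 1, 0]])
candidates = tri_training_candidates(preds)
```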
19.3.4 Label propagation on graphs 


If two datapoints are “similar” in some meaningful way, we might expect that they share a label. This 
idea has been referred to as the manifold assumption. Label propagation is a semi-supervised 
learning technique that leverages the manifold assumption to assign labels to unlabeled data. Label 
propagation first constructs a graph where the nodes are the data examples and the edge weights 
represent the degree of similarity. The node labels are known for nodes corresponding to labeled data 
but are unknown for unlabeled data. Label propagation then propagates the known labels across 
edges of the graph in such a way that there is minimal disagreement in the labels of a given node’s 
neighbors. This provides label guesses for the unlabeled data, which can then be used in the usual 
way for supervised training of a model. 

More specifically, the basic label propagation algorithm [ZG02] proceeds as follows: First, let 
$w_{i,j}$ denote a non-negative edge weight between $x_i$ and $x_j$ that provides a measure of similarity for 
the two (labeled or unlabeled) datapoints. Assuming that we have $M$ labeled datapoints and $N$ 
unlabeled datapoints, define the $(M + N) \times (M + N)$ transition matrix $T$ as having entries 


$$T_{i,j} = \frac{w_{i,j}}{\sum_k w_{k,j}} \tag{19.26}$$


$T_{i,j}$ represents the probability of propagating the label for node $j$ to node $i$. Further, define the 
$(M + N) \times C$ label matrix $Y$, where $C$ is the number of possible classes. The $i$th row of $Y$ represents 
the class probability distribution of datapoint $i$. Then, repeat the following steps until the values in 
$Y$ do not change significantly: First, use the transition matrix $T$ to propagate labels in $Y$ by setting 
$Y \leftarrow TY$. Then, re-normalize the rows of $Y$ by setting $Y_{i,c} \leftarrow Y_{i,c} / \sum_k Y_{i,k}$. Finally, replace the 
rows of $Y$ corresponding to labeled datapoints with their one-hot representation (i.e. $Y_{i,c} = 1$ if 
datapoint $i$ has ground-truth label $c$ and 0 otherwise). After convergence, guessed labels are chosen 
based on the highest class probability for each datapoint in $Y$. 

This algorithm iteratively uses the similarity of datapoints (encoded in the weights used to construct 
the transition matrix) to propagate information from the (fixed) labels onto the unlabeled data. At 
each iteration, the label distribution for a given datapoint is computed as the weighted average of 
the label distributions for all of its connected datapoints, where the weighting corresponds to the 
edge weights in T. It can be shown that this procedure converges to a single fixed point; computing 
that fixed point directly mainly involves inverting the matrix of unlabeled-to-unlabeled transition 
probabilities [ZG02]. 
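
The iterative procedure can be sketched in a few lines of NumPy. This is a minimal illustration rather than a reference implementation: the function name `label_propagation`, the convention of marking unlabeled points with -1, the uniform initialization of unlabeled rows, and the RBF weights in the example are all assumptions. It also assumes every column of the weight matrix has a non-zero sum.

```python
import numpy as np

def label_propagation(W, labels, num_classes, tol=1e-6, max_iter=1000):
    """Iterative label propagation on a similarity graph.

    W: (n, n) non-negative edge weights.
    labels: length-n integer array, with -1 marking unlabeled datapoints.
    Returns the (n, num_classes) matrix Y of class distributions.
    """
    W = np.asarray(W, dtype=float)
    labels = np.asarray(labels)
    T = W / W.sum(axis=0, keepdims=True)        # T[i, j] = w_ij / sum_k w_kj
    labeled = labels >= 0
    one_hot = np.eye(num_classes)[labels[labeled]]
    Y = np.full((W.shape[0], num_classes), 1.0 / num_classes)
    Y[labeled] = one_hot
    for _ in range(max_iter):
        Y_new = T @ Y                               # propagate labels
        Y_new /= Y_new.sum(axis=1, keepdims=True)   # re-normalize rows
        Y_new[labeled] = one_hot                    # clamp the known labels
        converged = np.abs(Y_new - Y).max() < tol
        Y = Y_new
        if converged:
            break
    return Y

# Four points on a line; only the two end points are labeled.
pos = np.array([0.0, 1.0, 2.0, 3.0])
W = np.exp(-(pos[:, None] - pos[None, :]) ** 2)
labels = np.array([0, -1, -1, 1])
Y = label_propagation(W, labels, num_classes=2)
guesses = Y.argmax(axis=1)
```

In this toy example each unlabeled point inherits the label of the labeled end point it is closest to.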

The overall approach can be seen as a form of transductive learning, since it is learning to 
predict labels for a fixed unlabeled dataset, rather than learning a model that generalizes. However, 
given the induced labeling, we can perform inductive learning in the usual way. 

The success of label propagation depends heavily on the notion of similarity used to construct the 
weights between different nodes (datapoints). For simple data, measuring the Euclidean distance 
between datapoints can be sufficient. However, for complex and high-dimensional data the Euclidean 
distance might not meaningfully reflect the likelihood that two datapoints share the same class. The 
similarity weights can also be set arbitrarily according to problem-specific knowledge. For a few 
examples of different ways of constructing the similarity graph, see Zhu [Zhu05, chapter 3]. For some 
recent papers that use this approach in conjunction with deep learning, see e.g., [BRR18; Isc+19]. 
19.3.5 Consistency regularization 


Consistency regularization leverages the simple idea that perturbing a given datapoint (or the 
model itself) should not cause the model’s output to change dramatically. Since measuring consistency 
in this way only makes use of the model’s outputs (and not ground-truth labels), it is readily applicable 
to unlabeled data and therefore can be used to create appropriate loss functions for semi-supervised 
learning. This idea was first proposed under the framework of “learning with pseudo-ensembles” 
[BAP14], with similar variants following soon thereafter [LA16; SJT16]. 

In its most general form, both the model $p_\theta(y|x)$ and the transformations applied to the input 
can be stochastic. For example, in computer vision problems we may transform the input by using 
data augmentation like randomly rotating or adding noise to the input image, and the network may 
include stochastic components like dropout (Section 13.5.4) or weight noise [Gra11]. A common and 
simple form of consistency regularization first samples $x' \sim q(x'|x)$ (where $q(x'|x)$ is the distribution 
induced by the stochastic input transformations) and then minimizes the loss $\|p_\theta(y|x) - p_\theta(y|x')\|^2$. 
In practice, the first term $p_\theta(y|x)$ is typically treated as fixed (i.e. gradients are not propagated 
through it). In the semi-supervised setting, the combined loss function over a batch of labeled data 
$(x_1, y_1), (x_2, y_2), \ldots, (x_M, y_M)$ and unlabeled data $x_1, x_2, \ldots, x_N$ is 


$$\mathcal{L}(\theta) = -\sum_{i=1}^{M} \log p_\theta(y = y_i|x_i) + \lambda \sum_{j=1}^{N} \|p_\theta(y|x_j) - p_\theta(y|x'_j)\|^2 \tag{19.27}$$


where $\lambda$ is a scalar hyperparameter that balances the importance of the loss on unlabeled data and, 
for simplicity, we write $x'_j$ to denote a sample drawn from $q(x'|x_j)$. 

The basic form of consistency regularization in Equation (19.27) reveals many design choices 
that impact the success of this semi-supervised learning approach. First, the value chosen for the 
$\lambda$ hyperparameter is important. If it is too large, then the model may not give enough weight to 
learning the supervised task and will instead start to reinforce its own bad predictions (as with 
confirmation bias in self-training). Since the model is often poor at the start of training before it has 
been trained on much labeled data, it is common in practice to initialize $\lambda$ to zero and increase 
its value over the course of training. 
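
As a rough sketch, the combined loss in Equation (19.27) and a ramp-up of its weight can be written as below, operating on model outputs that have already been computed. The function names, the linear form of the schedule, and the toy numbers are assumptions for illustration; other ramp-up shapes are possible.

```python
import numpy as np

def consistency_loss(p_labeled, y, p_clean, p_perturbed, lam):
    """Supervised cross-entropy plus weighted squared-error consistency term.

    p_labeled: (M, C) predicted probabilities for the labeled inputs.
    y: length-M integer class labels.
    p_clean, p_perturbed: (N, C) predictions for unlabeled x_j and perturbed x'_j.
    """
    p_labeled = np.asarray(p_labeled, dtype=float)
    supervised = -np.sum(np.log(p_labeled[np.arange(len(y)), y]))
    diff = np.asarray(p_clean, dtype=float) - np.asarray(p_perturbed, dtype=float)
    unsupervised = np.sum(diff ** 2)
    return supervised + lam * unsupervised

def lambda_schedule(step, ramp_steps, lam_max):
    """Hypothetical linear ramp of the unlabeled-loss weight from 0 to lam_max."""
    return lam_max * min(1.0, step / ramp_steps)

loss = consistency_loss(
    p_labeled=[[0.5, 0.5]], y=[0],
    p_clean=[[1.0, 0.0]], p_perturbed=[[0.0, 1.0]],
    lam=lambda_schedule(step=100, ramp_steps=100, lam_max=1.0),
)
```

In a real training loop the clean prediction would be treated as a constant when differentiating, as noted above.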

A second important consideration is the choice of random transformations applied to the input, i.e., $q(x'|x)$. 
Generally speaking, these transformations should be designed so that they do not change the label 
of x. As mentioned above, a common choice is to use domain-specific data augmentations. It has 
recently been shown that using strong data augmentations that heavily corrupt the input (but, 
arguably, still do not change the label) can produce particularly strong results [Xie+19; Ber+19a; 
Soh+20]. 

Figure 19.11: Comparison of the squared error and KL divergence losses for consistency regularization. 
This visualization is for a binary classification problem where it is assumed that the model’s output for 
the unperturbed input is 1. The figure plots the loss incurred for a particular value of the logit (i.e. the 
pre-activation fed into the output sigmoid nonlinearity) for the perturbed input. As the logit grows towards 
infinity, the model predicts a class label of 1 (in agreement with the prediction for the unperturbed input); as 
it grows towards negative infinity, the model predicts class 0. The squared error loss saturates (and has 
zero gradients) when the model predicts one class or the other with high probability, but the KL divergence 
grows without bound as the model predicts class 0 with more and more confidence. 


The use of data augmentation requires expert knowledge to determine what kinds of transformations 
are label-preserving and appropriate for a given problem. An alternative technique, called virtual 
adversarial training (VAT), instead transforms the input using an analytically-found perturbation 
designed to maximally change the model’s output. Specifically, VAT computes a perturbation $\delta$ that 
approximates $\delta = \arg\max_\delta D_{\mathrm{KL}}(p_\theta(y|x) \,\|\, p_\theta(y|x + \delta))$. The approximation is done by sampling $d$ 
from a multivariate Gaussian distribution, initializing $\delta = d$, and then setting 


$$\delta \leftarrow \nabla_\delta D_{\mathrm{KL}}(p_\theta(y|x) \,\|\, p_\theta(y|x + \delta)) \big|_{\delta = \xi d} \tag{19.28}$$


where $\xi$ is a small constant, typically $10^{-6}$. VAT then sets 


$$x' = x + \epsilon \frac{\delta}{\|\delta\|_2} \tag{19.29}$$


and proceeds as usual with consistency regularization (as in Equation (19.27)), where $\epsilon$ is a scalar 
hyperparameter that sets the L2-norm of the perturbation applied to $x$. 
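
To illustrate the VAT steps, the sketch below uses a linear softmax classifier, for which the gradient of the KL divergence with respect to the perturbation has a simple closed form. The linear model, the function names, and the example values are assumptions made so that the example is self-contained; for a deep network the gradient would instead be obtained by automatic differentiation.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()

def vat_perturbation(W, b, x, epsilon=1.0, xi=1e-6, rng=None):
    """One-step VAT perturbation for the model p(y|x) = softmax(W x + b)."""
    rng = np.random.default_rng() if rng is None else rng
    p = softmax(W @ x + b)               # clean prediction, treated as fixed
    d = rng.standard_normal(x.shape)     # random Gaussian direction
    q = softmax(W @ (x + xi * d) + b)    # prediction at delta = xi * d
    # For this linear model, grad_delta KL(p || q(delta)) = W^T (q - p).
    delta = W.T @ (q - p)
    # Rescale the perturbation to have L2-norm epsilon.
    return x + epsilon * delta / np.linalg.norm(delta)

rng = np.random.default_rng(0)
W = rng.standard_normal((3, 4))
b = np.zeros(3)
x = rng.standard_normal(4)
x_adv = vat_perturbation(W, b, x, epsilon=0.5, rng=rng)
```

The perturbed input `x_adv` can then be used in place of the augmented sample in the consistency loss.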

Consistency regularization can also profoundly affect the geometric properties of the training 
objective, and the trajectory of SGD, such that performance can particularly benefit from non- 
standard training procedures. For example, the Euclidean distances between weights at different 
training epochs are significantly larger for objectives that use consistency regularization. Athiwaratkun 
et al. [Ath+19] show that a variant of stochastic weight averaging (SWA) [Izm+18] can achieve 
state-of-the-art performance on semi-supervised learning tasks by exploiting the geometric properties 
of consistency regularization. 

A final consideration when using consistency regularization is the function used to measure the 
difference between the network’s output with and without perturbations. Equation (19.27) uses the 
squared L2 distance (also referred to as the Brier score), which is a common choice [SJT16; TV17; 
LA16; Ber+19b]. It is also common to use the KL divergence $D_{\mathrm{KL}}(p_\theta(y|x) \,\|\, p_\theta(y|x'))$ in analogy 
with the cross-entropy loss (i.e. KL divergence between ground-truth label and prediction) used for 
labeled examples [Miy+18; Ber+19a; Xie+19]. The gradient of the squared-error loss approaches zero 
as the model’s predictions on the perturbed and unperturbed input differ more and more, assuming 
the model uses a softmax nonlinearity on its output. Using the squared-error loss therefore has a 
possible advantage that the model is not updated when its predictions are very unstable. However, 
the KL divergence has the same scale as the cross-entropy loss used for labeled data, which makes for 
more intuitive tuning of the unlabeled loss hyperparameter $\lambda$. A comparison of the two loss functions 
is shown in Figure 19.11. 
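
The different gradient behavior of the two losses can be checked numerically in the binary setting of Figure 19.11, where the target for the unperturbed input is 1 and the perturbed prediction is a sigmoid of a logit $z$. The function names are assumptions; the gradients are derived by hand from the two loss expressions in the comments.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def squared_error_and_grad(z):
    """Loss (1 - sigmoid(z))^2 and its derivative with respect to z."""
    s = sigmoid(z)
    return (1.0 - s) ** 2, -2.0 * s * (1.0 - s) ** 2

def kl_and_grad(z):
    """Loss -log sigmoid(z) (KL to a target of 1) and its derivative."""
    s = sigmoid(z)
    return np.logaddexp(0.0, -z), -(1.0 - s)

# A perturbed prediction that strongly disagrees with the target.
sq_loss, sq_grad = squared_error_and_grad(-10.0)
kl_loss, kl_grad = kl_and_grad(-10.0)
```

At a logit of -10 the squared-error gradient is close to zero even though the prediction is badly wrong, whereas the KL gradient has magnitude close to 1 and the KL loss keeps growing.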
19.3.6 Deep generative models * 


Generative models provide a natural way of making use of unlabeled data through learning a model of 
the marginal distribution by minimizing $\mathcal{L}_U = -\sum_n \log p_\theta(x_n)$. Various approaches have leveraged 
generative models for semi-supervised learning by developing ways to use the model of $p_\theta(x_n)$ to help produce 
a better supervised model. 


19.3.6.1 Variational autoencoders 


In Section 20.3.5, we describe the variational autoencoder (VAE), which defines a probabilistic model 
of the joint distribution of data x and latent variables z. Data is assumed to be generated by first 
sampling $z \sim p(z)$ and then sampling $x \sim p(x|z)$. For learning, the VAE uses an encoder $q_\phi(z|x)$ 
to approximate the posterior and a decoder $p_\theta(x|z)$ to approximate the likelihood. The encoder 
and decoder are typically deep neural networks. The parameters of the encoder and decoder can be 
jointly trained by maximizing the evidence lower bound (ELBO) of data. 

The marginal distribution of latent variables p(z) is often chosen to be a simple distribution like 
a diagonal-covariance Gaussian. In practice, this can make the latent variables z more amenable 
to downstream classification thanks to the facts that $z$ is typically lower-dimensional than $x$, that 
z is constructed via cascaded nonlinear transformations, and that the dimensions of the latent 
variables are designed to be independent. In other words, the latent variables can provide a (learned) 
representation where data may be more easily separable. In [Kin+14], this approach is called M1 
and it is indeed shown that the latent variables can be used to train stronger models when labels 
are scarce. (The general idea of unsupervised learning of representations to help with downstream 
classification tasks is described further in Section 19.2.4.) 

An alternative approach to leveraging VAEs, also proposed in [Kin+14] and called M2, has the 
form 


$$p_\theta(x, y) = p_\theta(y) p_\theta(x|y) = p_\theta(y) \int p_\theta(x|y, z) p_\theta(z) \, dz \tag{19.30}$$


where $z$ is a latent variable, $p_\theta(z) = \mathcal{N}(z|\mu_\theta, \Sigma_\theta)$ is the latent prior (typically we fix $\mu_\theta = 0$ and 
$\Sigma_\theta = I$), $p_\theta(y) = \mathrm{Cat}(y|\pi_\theta)$ is the label prior, and $p_\theta(x|y, z) = p(x|f_\theta(y, z))$ is the likelihood, such as 
a Gaussian, with parameters computed by $f$ (a deep neural network). The main innovation of this 
approach is to assume that data is generated according to both a latent class variable $y$ and 
the continuous latent variable $z$. The class variable $y$ is observed for labeled data and unobserved for 
unlabeled data. 

To compute the likelihood for the labeled data, $p_\theta(x, y)$, we need to marginalize over $z$, which we 
can do by using an inference network of the form 


$$q_\phi(z|y, x) = \mathcal{N}(z|\mu_\phi(y, x), \mathrm{diag}(\sigma_\phi(x))) \tag{19.31}$$


We then use the following variational lower bound 


$$\log p_\theta(x, y) \geq \mathbb{E}_{q_\phi(z|x,y)} [\log p_\theta(x|y, z) + \log p_\theta(y) + \log p_\theta(z) - \log q_\phi(z|x, y)] = -\mathcal{L}(x, y) \tag{19.32}$$


as is standard for VAEs (see Section 20.3.5). The only difference is that we observe two kinds of 
data: x and y. 
To compute the likelihood for the unlabeled data, $p_\theta(x)$, we need to marginalize over $z$ and $y$, 
which we can do by using an inference network of the form 


$$
\begin{aligned}
q_\phi(z, y|x) &= q_\phi(z|x) q_\phi(y|x) && (19.33) \\
q_\phi(z|x) &= \mathcal{N}(z|\mu_\phi(x), \mathrm{diag}(\sigma_\phi(x))) && (19.34) \\
q_\phi(y|x) &= \mathrm{Cat}(y|\pi_\phi(x)) && (19.35)
\end{aligned}
$$


Note that $q_\phi(y|x)$ acts like a discriminative classifier that imputes the missing labels. We then use 
the following variational lower bound: 


$$
\begin{aligned}
\log p_\theta(x) &\geq \mathbb{E}_{q_\phi(z,y|x)} [\log p_\theta(x|y, z) + \log p_\theta(y) + \log p_\theta(z) - \log q_\phi(z, y|x)] && (19.36) \\
&= -\sum_y q_\phi(y|x) \mathcal{L}(x, y) + \mathbb{H}(q_\phi(y|x)) = -\mathcal{U}(x) && (19.37)
\end{aligned}
$$


Note that the discriminative classifier $q_\phi(y|x)$ is only used to compute the log-likelihood of the 
unlabeled data, which is undesirable. We can therefore add an extra classification loss on the 
supervised data, to get the following overall objective function: 


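$$\mathcal{L}(\theta) = \mathbb{E}_{(x,y) \sim \mathcal{D}_L} [\mathcal{L}(x, y)] + \mathbb{E}_{x \sim \mathcal{D}_U} [\mathcal{U}(x)] + \alpha \, \mathbb{E}_{(x,y) \sim \mathcal{D}_L} [-\log q_\phi(y|x)] \tag{19.38}$$


where $\alpha$ is a hyperparameter that controls the relative weight of generative and discriminative 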
learning. 

Of course, the probabilistic model used in M2 is just one of many ways to decompose the dependencies 
between the observed data, the class labels, and the continuous latent variables. There are also many 
ways other than variational inference to perform approximate inference. The best technique will 
be problem dependent, but overall the main advantage of the generative approach is that we can 
incorporate domain knowledge. For example, we can model the missing data mechanism, since the 
absence of a label may be informative about the underlying data (e.g., people may be reluctant to 
answer a survey question about their health if they are unwell). 


19.3.6.2 Generative adversarial networks 


Generative adversarial networks (GANs) (described in more detail in the sequel to this book, 
[Mur23]) are a popular class of generative models that learn an implicit model of the data distribution. 
They consist of a generator network, which maps samples from a simple latent distribution to the 
data space, and a critic network, which attempts to distinguish between the outputs of the generator 
and samples from the true data distribution. The generator is trained to generate samples that the 
critic classifies as “real”. 

Since standard GANs do not produce a learned latent representation of a given datapoint and do 
not learn an explicit model of the data distribution, we cannot use the same approaches as were used 
for VAEs. Instead, semi-supervised learning with GANs is typically done by modifying the critic 
so that it outputs either a class label or “fake” instead of simply classifying real vs. fake [Sal+16; 
Ode16]. For labeled real data, the critic is trained to output the appropriate class label, and for 
unlabeled real data, it is trained to raise the probability of any of the class labels. As with standard 
GAN training, the critic is trained to classify outputs from the generator as fake and the generator is 
trained to fool the critic. 


Figure 19.12: Diagram of the semi-supervised GAN framework. The discriminator is trained to output the 
class of labeled datapoints (red), a “fake” label for outputs from the generator (yellow), and any label for 
unlabeled data (green). 


In more detail, let $p_\theta(y|x)$ denote the critic with $C + 1$ outputs corresponding to $C$ classes plus a 
“fake” class, and let $G(z)$ denote the generator which takes as input samples from the prior distribution 
$p(z)$. Let us assume that we are using the standard cross-entropy GAN loss as originally proposed in 
[Goo+14]. Then the critic’s loss is 


$$-\mathbb{E}_{x,y \sim p(x,y)} \log p_\theta(y|x) - \mathbb{E}_{x \sim p(x)} \log[1 - p_\theta(y = C+1|x)] - \mathbb{E}_{z \sim p(z)} \log p_\theta(y = C+1|G(z)) \tag{19.39}$$


This tries to maximize the probability of the correct class for the labeled examples, to minimize the 
probability of the fake class for real unlabeled examples, and to maximize the probability of the fake 
class for generated examples. The generator’s loss is simpler, namely 


$$\mathbb{E}_{z \sim p(z)} \log p_\theta(y = C+1|G(z)) \tag{19.40}$$


A diagram visualizing the semi-supervised GAN framework is shown in Figure 19.12. 
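
The critic and generator losses in Equations (19.39) and (19.40) can be estimated from batches of critic outputs as sketched below. The function name `ssgan_losses`, the convention that the last column is the “fake” class, the constant `eps`, and the toy inputs are assumptions for illustration; the critic and generator networks themselves are not modeled here.

```python
import numpy as np

def ssgan_losses(p_labeled, y, p_unlabeled, p_generated, eps=1e-12):
    """Batch estimates of the semi-supervised GAN critic and generator losses.

    Each p_* array has C + 1 columns of critic probabilities, with the
    last column corresponding to the "fake" class.
    """
    p_labeled = np.asarray(p_labeled, dtype=float)
    p_unlabeled = np.asarray(p_unlabeled, dtype=float)
    p_generated = np.asarray(p_generated, dtype=float)
    # Log-probability of the correct class for labeled real data.
    labeled_term = np.mean(np.log(p_labeled[np.arange(len(y)), y] + eps))
    # Log-probability of "not fake" for unlabeled real data.
    unlabeled_term = np.mean(np.log(1.0 - p_unlabeled[:, -1] + eps))
    # Log-probability of "fake" for generated data.
    fake_term = np.mean(np.log(p_generated[:, -1] + eps))
    critic_loss = -(labeled_term + unlabeled_term + fake_term)
    generator_loss = fake_term
    return critic_loss, generator_loss

critic_loss, generator_loss = ssgan_losses(
    p_labeled=[[1.0, 0.0, 0.0]], y=[0],
    p_unlabeled=[[0.5, 0.5, 0.0]],
    p_generated=[[0.25, 0.25, 0.5]],
)
```

In this example the critic is perfect on the real data, so its loss comes entirely from generated samples that it only classifies as fake with probability 0.5.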


19.3.6.3 Normalizing flows 


Normalizing flows (described in more detail in the sequel to this book, [Mur23]) are a tractable 
way to define deep generative models. More precisely, they define an invertible mapping $f_\theta : \mathcal{X} \to \mathcal{Z}$, 
with parameters $\theta$, from the data space $\mathcal{X}$ to the latent space $\mathcal{Z}$. The density in data space can be 
written starting from the density in the latent space using the change of variables formula: 


$$p(x) = p(f(x)) \cdot \left| \det \left( \frac{\partial f}{\partial x} \right) \right|. \tag{19.41}$$


Figure 19.13: Combining self-supervised learning on unlabeled data (left), supervised fine-tuning (middle), 
and self-training on pseudo-labeled data (right). From Figure 3 of [Che+20c]. Used with kind permission of 
Ting Chen. 


We can extend this to semi-supervised learning, as proposed in [Izm+20]. For class labels 
$y \in \{1 \ldots C\}$, we can specify the latent distribution, conditioned on a label $k$, as Gaussian with mean 
$\mu_k$ and covariance $\Sigma_k$: $p(z|y = k) = \mathcal{N}(z|\mu_k, \Sigma_k)$. The marginal distribution of $z$ is then a Gaussian 
mixture. The likelihood for labeled data is then 


$$p(x|y = k) = \mathcal{N}(f(x)|\mu_k, \Sigma_k) \cdot \left| \det \left( \frac{\partial f}{\partial x} \right) \right| \tag{19.42}$$


and the likelihood for data with unknown label is $p(x) = \sum_k p(x|y = k) p(y = k)$. 
For semi-supervised learning we can then maximize the joint likelihood of the labeled data $\mathcal{D}_\ell$ and 
unlabeled data $\mathcal{D}_u$: 


$$p(\mathcal{D}_\ell, \mathcal{D}_u|\theta) = \prod_{(x_i, y_i) \in \mathcal{D}_\ell} p(x_i, y_i) \prod_{x_j \in \mathcal{D}_u} p(x_j), \tag{19.43}$$


over the parameters $\theta$ of the bijective function $f$, which learns a density model for a Bayes classifier. 
Given a test point $x$, the model predictive distribution is given by 


$$p(y = c|x) = \frac{p(x|y = c) p(y = c)}{p(x)} = \frac{p(x|y = c) p(y = c)}{\sum_{k=1}^{C} p(x|y = k) p(y = k)} = \frac{\mathcal{N}(f(x)|\mu_c, \Sigma_c)}{\sum_{k=1}^{C} \mathcal{N}(f(x)|\mu_k, \Sigma_k)}, \tag{19.44}$$


where we have assumed $p(y = c) = 1/C$. We can make predictions for a test point $x$ with the Bayes 
decision rule $y = \arg\max_{c \in \{1, \ldots, C\}} p(y = c|x)$. 
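
Given the latent code $z = f(x)$ of a test point, the predictive distribution in Equation (19.44) only requires evaluating the class-conditional Gaussians, since the Jacobian factor is shared by all classes and cancels. The sketch below assumes the flow has already been applied; the function names and the two-class example are assumptions.

```python
import numpy as np

def gaussian_log_density(z, mean, cov):
    """Log density of a multivariate Gaussian N(z | mean, cov)."""
    d = z.shape[0]
    diff = z - mean
    _, logdet = np.linalg.slogdet(cov)
    maha = diff @ np.linalg.solve(cov, diff)
    return -0.5 * (d * np.log(2.0 * np.pi) + logdet + maha)

def class_posterior(z, means, covs):
    """p(y = c | x) for a uniform class prior, given the latent code z = f(x)."""
    log_p = np.array([gaussian_log_density(z, m, S) for m, S in zip(means, covs)])
    log_p -= log_p.max()     # stabilize before exponentiating
    p = np.exp(log_p)
    return p / p.sum()

means = [np.array([-2.0, 0.0]), np.array([2.0, 0.0])]
covs = [np.eye(2), np.eye(2)]
post = class_posterior(np.array([2.0, 0.0]), means, covs)
prediction = post.argmax()   # Bayes decision rule
```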


19.3.7 Combining self-supervised and semi-supervised learning 


It is possible to combine self-supervised and semi-supervised learning. For example, [Che+20c] 
use SimCLR (Section 19.2.4.4) to perform self-supervised representation learning on the unlabeled 
data; they then fine-tune this representation on a small labeled dataset (as in transfer learning, 
Section 19.2), and finally, they apply the trained model back to the original unlabeled dataset, and 
distill the predictions from this teacher model T into a student model S. (Knowledge distillation 
is the name given to the approach of training one model on the predictions of another, as originally 
proposed in [HVD14].) That is, after fine-tuning T, they train S by minimizing 


$$\mathcal{L}(T) = -\sum_{x_i \in \mathcal{D}} \left[ \sum_y p^T(y|x_i; \tau) \log p^S(y|x_i; \tau) \right] \tag{19.45}$$


where $\tau > 0$ is a temperature parameter applied to the softmax output, which is used to perform 
label smoothing. If $S$ has the same form as $T$, this is known as self-training, as discussed in 
Section 19.3.1. However, normally the student $S$ is smaller than the teacher $T$. (For example, $T$ might 
be a high capacity model, and S is a lightweight version that runs on a phone.) See Figure 19.13 for 
an illustration of the overall approach. 
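
The distillation loss in Equation (19.45) can be sketched as a temperature-scaled cross-entropy between teacher and student outputs. The code below works on precomputed logits; the function names and example logits are assumptions, and in practice the teacher’s outputs would be held fixed while only the student is updated.

```python
import numpy as np

def log_softmax(logits, tau):
    """Log of the temperature-scaled softmax, computed stably."""
    z = np.asarray(logits, dtype=float) / tau
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.sum(np.exp(z), axis=-1, keepdims=True))

def distillation_loss(teacher_logits, student_logits, tau=1.0):
    """Cross-entropy between teacher and student distributions, summed over examples."""
    p_teacher = np.exp(log_softmax(teacher_logits, tau))
    return -np.sum(p_teacher * log_softmax(student_logits, tau))

teacher = np.array([[2.0, 0.0, -1.0]])
student = np.array([[0.0, 2.0, -1.0]])
loss_mismatch = distillation_loss(teacher, student, tau=1.0)
loss_match = distillation_loss(teacher, teacher, tau=1.0)
```

The loss is smallest when the student reproduces the teacher’s distribution, in which case it equals the entropy of the teacher’s prediction.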


19.4 Active learning 


In active learning, the goal is to identify the true predictive mapping $y = f(x)$ by querying as few 
(x,y) points as possible. There are three main variants. In query synthesis, the algorithm gets 
to choose any input $x$, and can ask for its corresponding output $y = f(x)$. In pool-based active 
learning, there is a large, but fixed, set of unlabeled data points, and the algorithm gets to ask for 
a label for one or more of these points. Finally, in stream-based active learning, the incoming 
data is arriving continuously, and the algorithm must choose whether it wants to request a label for 
the current input or not. 

There are various closely related problems. In Bayesian optimization the goal is to estimate
the location of the global optimum $x^* = \operatorname{argmin}_x f(x)$ in as few queries as possible; typically we fit
a surrogate (response surface) model to the intermediate (x,y) queries, to decide which question
to ask next. In experiment design, the goal is to infer a parameter vector of some model, using
carefully chosen data samples $\mathcal{D} = \{x_1, \ldots, x_N\}$, i.e., we want to estimate $p(\theta | \mathcal{D})$ using as little data
as possible. (This can be thought of as an unsupervised, or generalized, form of active learning.)

In this section, we give a brief review of the pool-based approach to active learning. For more
details, see e.g., [Set12].


19.4.1 Decision-theoretic approach 


In the decision-theoretic approach to active learning, proposed in [KHB07; RM01], we define the
utility of querying x in terms of the value of information. In particular, we define the utility of
issuing query x as

$$U(x) \triangleq \mathbb{E}_{p(y | x, \mathcal{D})} \left[ \min_a \left( \rho(a | \mathcal{D}) - \rho(a | \mathcal{D}, (x, y)) \right) \right] \tag{19.46}$$

where $\rho(a | \mathcal{D}) = \mathbb{E}_{p(\theta | \mathcal{D})} [\ell(\theta, a)]$ is the posterior expected loss of taking some future action a given the
data $\mathcal{D}$ observed so far. Unfortunately, evaluating U(x) for each x is quite expensive, since for each
possible response y we might observe, we have to update our beliefs given (x, y) to see what effect it
might have on our future decisions (similar to a look-ahead search technique applied to belief states).


19.4.2 Information-theoretic approach 


In the information-theoretic approach to active supervised learning, we avoid using task-specific loss
functions, and instead focus on learning our model as well as we can. In particular, [Lin56] proposed
to define the utility of querying x in terms of information gain about the parameters $\theta$, i.e., the
reduction in entropy:

$$U(x) \triangleq \mathbb{H}(p(\theta | \mathcal{D})) - \mathbb{E}_{p(y | x, \mathcal{D})} \left[ \mathbb{H}(p(\theta | \mathcal{D}, x, y)) \right] \tag{19.47}$$

(Note that the first term is a constant wrt x, but we include it for later convenience.) Exercise 19.1
asks you to show that this objective is identical to the expected change in the posterior over the
parameters, which is given by

$$U'(x) \triangleq \mathbb{E}_{p(y | x, \mathcal{D})} \left[ D_{\mathrm{KL}}(p(\theta | \mathcal{D}, x, y) \,\|\, p(\theta | \mathcal{D})) \right] \tag{19.48}$$

Using symmetry of the mutual information, we can rewrite Equation (19.47) as follows:

\begin{align}
U(x) &= \mathbb{H}(p(\theta | \mathcal{D})) - \mathbb{E}_{p(y | x, \mathcal{D})} \left[ \mathbb{H}(p(\theta | \mathcal{D}, x, y)) \right] \tag{19.49} \\
&= \mathbb{I}(\theta, y | \mathcal{D}, x) \tag{19.50} \\
&= \mathbb{H}(p(y | x, \mathcal{D})) - \mathbb{E}_{p(\theta | \mathcal{D})} \left[ \mathbb{H}(p(y | x, \theta)) \right] \tag{19.51}
\end{align}

The advantage of this approach is that we now only have to reason about the uncertainty of the
predictive distribution over outputs y, not over the parameters $\theta$.

Equation (19.51) has an interesting interpretation. The first term prefers examples x for which
there is uncertainty in the predicted label. Just using this as a selection criterion is called maximum
entropy sampling [SW87]. However, this can have problems with examples which are inherently
ambiguous or mislabeled. The second term in Equation (19.51) will discourage such behavior, since
it prefers examples x for which the predicted label is fairly certain once we know $\theta$; this will avoid
picking inherently hard-to-predict examples. In other words, Equation (19.51) will select examples
x for which the model makes confident predictions which are highly diverse. This approach has
therefore been called Bayesian active learning by disagreement or BALD [Hou+12].

This method can be used to train classifiers for other domains where expert labels are hard to
acquire, such as medical images or astronomical images [Wal+20].
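
The two terms of Equation (19.51) can be estimated from a set of posterior samples. The sketch below assumes we already have an array `probs[s, i, c]` holding the class probabilities predicted for pool point i under parameter sample s (for example from an ensemble or MC dropout); the numbers are invented to show the three qualitatively different cases.

```python
import numpy as np

def entropy(p, axis=-1):
    """Shannon entropy in nats; probabilities are clipped to avoid log(0)."""
    p = np.clip(p, 1e-12, 1.0)
    return -np.sum(p * np.log(p), axis=axis)

def bald_scores(probs):
    """probs[s, i, c] = p(y = c | x_i, theta_s) for S posterior samples."""
    mean_probs = probs.mean(axis=0)                   # approximates p(y | x, D)
    predictive_entropy = entropy(mean_probs)          # first term of (19.51)
    expected_entropy = entropy(probs).mean(axis=0)    # second term of (19.51)
    return predictive_entropy - expected_entropy

probs = np.array([
    # x_0: samples agree and are confident
    # x_1: samples agree but are all unsure (inherently ambiguous)
    # x_2: samples are confident but disagree
    [[0.99, 0.01], [0.5, 0.5], [0.99, 0.01]],
    [[0.99, 0.01], [0.5, 0.5], [0.01, 0.99]],
])
scores = bald_scores(probs)
best = int(np.argmax(scores))
```

Maximum entropy sampling would rank x_1 and x_2 equally, whereas the BALD score is zero for the ambiguous point x_1 and large only for x_2, where the parameter samples disagree.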


19.4.3 Batch active learning 


So far, we have assumed a greedy or myopic strategy, in which we select a single example x,
as if it were the last datapoint to be selected. But sometimes we have a budget to collect a set
of B samples, call them (X,Y). In this case, the information gain criterion becomes
$U(x) = \mathbb{H}(p(\theta | \mathcal{D})) - \mathbb{E}_{p(Y | x, \mathcal{D})} [\mathbb{H}(p(\theta | Y, x, \mathcal{D}))]$. Unfortunately, optimizing this is NP-hard in the horizon
length B [KLQ95; KG05].

Fortunately, under certain conditions, the greedy strategy is near-optimal, as we now explain. Let
us fix query x and define $f(Y) \triangleq \mathbb{H}(p(\theta | \mathcal{D})) - \mathbb{H}(p(\theta | Y, x, \mathcal{D}))$ as the information gain function, so
$U(x) = \mathbb{E}_Y [f(Y, x)]$. It is clear that $f(\emptyset) = 0$, and that f is non-decreasing, meaning $f(Y^{\text{large}}) \geq f(Y^{\text{small}})$,
due to the “more information never hurts” principle. Furthermore, [KG05] proved that
f is submodular. As a consequence, a sequential greedy approach is within a constant factor of
optimal. If we combine this greedy technique with the BALD objective, we get a method called
BatchBALD [KAG19].


19.5 Meta-learning 
We can think of a learning algorithm as a function A that maps data to a parameter estimate,
$\hat{\theta} = A(\mathcal{D})$. The function A usually has its own parameters, call them $\phi$, such as the initial values
for $\theta$, or the learning rate, etc. We denote this by $\hat{\theta} = A(\mathcal{D}; \phi)$. We can imagine learning $\phi$ itself,
given a collection of datasets $\mathcal{D}_{1:J}$ and some meta-learning algorithm M, i.e., $\hat{\phi} = M(\mathcal{D}_{1:J})$. We
can then apply $A(\cdot; \phi)$ to learn the parameters $\hat{\theta}_{J+1}$ on some new dataset $\mathcal{D}_{J+1}$. There are many
techniques for meta-learning; see e.g., [Van18; HRP21] for recent reviews. Below we discuss one
particularly popular method. (Note that meta-learning is also called learning to learn [TP97].)

Figure 19.14: Illustration of a hierarchical Bayesian model for meta-learning. Generated by
hbayes_maml.ipynb.


19.5.1 Model-agnostic meta-learning (MAML) 


A natural approach to meta-learning is to use a hierarchical Bayesian model, as illustrated in
Figure 19.14. The parameters for each task $\theta_j$ are assumed to come from a common prior, $p(\theta_j | \phi)$,
which can be used to help pool statistical strength from multiple data-poor problems. Meta-learning
becomes equivalent to learning the prior $\phi$. Rather than performing full Bayesian inference in
this model, a more efficient approach is to use the following empirical Bayes (Section 4.6.5.3)
approximation:

$$\phi^* = \operatorname{argmax}_{\phi} \frac{1}{J} \sum_{j=1}^{J} \log p(\mathcal{D}^j_{\text{valid}} | \hat{\theta}_j(\phi, \mathcal{D}^j_{\text{train}})) \tag{19.52}$$

where $\hat{\theta}_j = \hat{\theta}(\phi, \mathcal{D}^j_{\text{train}})$ is a point estimate of the parameters for task j based on $\mathcal{D}^j_{\text{train}}$ and prior $\phi$,
and where we use a cross-validation approximation to the marginal likelihood (Section 5.2.4).

To compute the point estimate of the parameters for the target task $\hat{\theta}_{J+1}$, we use K steps of a
gradient ascent procedure starting at $\phi$ with a learning rate of $\eta$. This is known as model-agnostic
meta-learning or MAML [FAL17]. This can be shown to be equivalent to an approximate MAP
estimate using a Gaussian prior centered at $\phi$, where the strength of the prior is controlled by the
number of gradient steps [San96; Gra+18]. (This is an example of fast adaptation of the
task-specific weights starting from the shared prior $\phi$.)
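
The fast adaptation step can be illustrated with a toy example. The sketch below shows only the inner loop (K gradient steps starting from the shared initialization) for a noise-free linear regression task; it does not implement the outer meta-learning optimization of Equation (19.52). Gradient ascent on a Gaussian log likelihood is written here as gradient descent on squared error, and all values are synthetic assumptions.

```python
import numpy as np

def mse(theta, X, y):
    """Mean squared error of a linear model with weights theta."""
    return np.mean((X @ theta - y) ** 2)

def adapt(phi, X, y, num_steps=5, lr=0.1):
    """Take K gradient steps on the task loss, starting from phi."""
    theta = phi.copy()
    for _ in range(num_steps):
        grad = 2.0 * X.T @ (X @ theta - y) / len(y)
        theta = theta - lr * grad
    return theta

rng = np.random.default_rng(0)
phi = np.array([1.0, -1.0])                 # shared initialization (plays the role of the prior)
w_task = phi + 0.3 * rng.normal(size=2)     # hypothetical task weights near phi
X = rng.normal(size=(10, 2))                # small task-specific training set
y = X @ w_task
theta_few = adapt(phi, X, y, num_steps=1)
theta_many = adapt(phi, X, y, num_steps=200)
```

With few steps the estimate stays near the initialization, and with many steps it approaches the task-specific solution, which gives some intuition for why the number of steps acts like a prior strength.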
Figure 19.15: Illustration of meta-learning for few-shot learning. Here, each task is a 3-way-2-shot classification
problem because each training task contains a support set with three classes, each with two examples. From
https://bit.ly/3rruSjw. Copyright (2019) Borealis AI. Used with kind permission of Simon Prince and
April Cooper.


19.6 Few-shot learning 


People can learn to predict from very few labeled examples. This is called few-shot learning. In 
the extreme in which the person or system learns from a single example of each class, this is called 
one-shot learning, and if no labeled examples are given, it is called zero-shot learning. 

A common way to evaluate methods for few-shot learning (FSL) is to use C-way N-shot classification, in which
the system is expected to learn to classify C classes using just N training examples of each class. 
Typically N and C are very small, e.g., Figure 19.15 illustrates the case where we have C = 3 classes, 
each with N = 2 examples. Since the amount of data from the new domain (here, ducks, dolphins 
and hens) is so small, we cannot expect to learn from scratch. Therefore we turn to meta-learning. 

During training, the meta-algorithm M trains on a labeled support set from group j, returns a
predictor $f^j$, which is then evaluated on a disjoint query set also from group j. We optimize M over
all J groups. Finally we can apply M to our new labeled support set to get $f^{\text{test}}$, which is applied
to the query set from the test domain. This is illustrated in Figure 19.15. We see that there is no 
overlap between the classes in the two training tasks ({cat, lamb, pig} and {dog, shark, lion}) and 
those in the test task ({duck, dolphin, hen}). Thus the algorithm M must learn to predict image 
classes in general rather than any particular set of labels. 
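
To make the C-way N-shot protocol concrete, the sketch below samples one training episode (a support set and a disjoint query set) from a labeled pool. The function name `sample_episode`, the pool size, and the number of query examples per class are illustrative assumptions, not part of the evaluation protocol described above.

```python
import numpy as np

def sample_episode(labels, num_ways, num_shots, num_queries, rng):
    """Return (support_idx, query_idx, classes) for one C-way N-shot task."""
    classes = rng.choice(np.unique(labels), size=num_ways, replace=False)
    support, query = [], []
    for c in classes:
        idx = rng.permutation(np.flatnonzero(labels == c))
        support.extend(idx[:num_shots])                          # N labeled examples
        query.extend(idx[num_shots:num_shots + num_queries])     # disjoint evaluation points
    return np.array(support), np.array(query), classes

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(10), 20)   # hypothetical pool: 10 classes, 20 examples each
support_idx, query_idx, classes = sample_episode(
    labels, num_ways=3, num_shots=2, num_queries=4, rng=rng)
```

Each call draws a fresh set of classes, so a meta-learner trained on many such episodes cannot rely on any fixed label set.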

There are many approaches to few-shot learning. We discuss one such method in Section 19.6.1. 
For more methods, see e.g., [Wan+20b].


19.6.1 Matching networks 


One approach to few-shot learning is to learn a distance metric on some other dataset, and then to
use $d_\theta(x, x')$ inside of a nearest neighbor classifier. Essentially this defines a semi-parametric model
of the form $p_\theta(y | x, \mathcal{S})$, where $\mathcal{S}$ is the small labeled dataset (known as the support set), and $\theta$ are the
parameters of the distance function. This approach is widely used for fine-grained classification
tasks, where there are many different visually similar categories, such as face images from a gallery,
or product images from a catalog.

Figure 19.16: Illustration of a matching network for one-shot learning. From Figure 1 of [Vin+16]. Used
with kind permission of Oriol Vinyals.

An extension of this approach is to learn a function of the form

$$p_\theta(y | x, \mathcal{S}) = \mathbb{I}\left( y = \sum_{n \in \mathcal{S}} a_\theta(x, x_n; \mathcal{S})\, y_n \right) \tag{19.53}$$

where $a_\theta(x, x_n; \mathcal{S}) \in \mathbb{R}^+$ is some kind of adaptive similarity kernel. For example, we can use an
attention kernel of the form

$$a(x, x_n; \mathcal{S}) = \frac{\exp(c(f(x), g(x_n)))}{\sum_{n'=1}^{N} \exp(c(f(x), g(x_{n'})))} \tag{19.54}$$

where c(u,v) is the cosine distance. (We can make f and g be the same function if we want.)
Intuitively, the attention kernel will compare x to $x_n$ in the context of all the labeled examples,
which provides an implicit signal about which feature dimensions are relevant. (We discuss attention
mechanisms in more detail in Section 15.4.) This is called a matching network [Vin+16]. See
Figure 19.16 for an illustration.

We can train the f and g functions using multiple small datasets, as in meta-learning (Section 19.5).
More precisely, let $\mathcal{D}$ be a large labeled dataset (e.g., ImageNet), and let $p(\mathcal{L})$ be a distribution over
its labels. We create a task by sampling a small set of labels (say 25), $\mathcal{L} \sim p(\mathcal{L})$, and then sampling a
small support set of examples from $\mathcal{D}$ with those labels, $\mathcal{S} \sim \mathcal{L}$, and finally sampling a small test set
with those same labels, $\mathcal{T} \sim \mathcal{L}$. We then train the model to predict the test labels given the support
set, i.e., we optimize the following objective:

$$L(\theta; \mathcal{D}) = \mathbb{E}_{\mathcal{L} \sim p(\mathcal{L})} \left[ \mathbb{E}_{\mathcal{S} \sim \mathcal{L}, \mathcal{T} \sim \mathcal{L}} \left[ \sum_{(x, y) \in \mathcal{T}} \log p_\theta(y | x, \mathcal{S}) \right] \right] \tag{19.55}$$

After training, we freeze $\theta$, and apply Equation (19.53) to a test support set $\mathcal{S}$.
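
A minimal sketch of prediction with an attention kernel is given below. It makes two simplifying assumptions that are not from the text: the embedding functions f and g are both the identity (in practice they would be learned networks), and the prediction is read off as the class with the largest attention-weighted vote over one-hot support labels.

```python
import numpy as np

def cosine(u, v):
    """Cosine of the angle between two vectors."""
    return u @ v / (np.linalg.norm(u) * np.linalg.norm(v))

def attention_kernel(x, support_x, f, g):
    """Softmax over cosine scores between f(x) and each g(x_n)."""
    scores = np.array([cosine(f(x), g(xn)) for xn in support_x])
    weights = np.exp(scores - scores.max())
    return weights / weights.sum()

def predict(x, support_x, support_y, num_classes, f, g):
    """Attention-weighted vote over one-hot support labels, then argmax."""
    a = attention_kernel(x, support_x, f, g)
    one_hot = np.eye(num_classes)[support_y]
    class_weights = a @ one_hot
    return int(np.argmax(class_weights)), class_weights

identity = lambda v: v   # assumption: placeholder for learned embeddings f and g
support_x = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
support_y = np.array([0, 0, 1, 1])
label, class_weights = predict(
    np.array([0.2, 1.0]), support_x, support_y, 2, identity, identity)
```

Because the weights are normalized over the whole support set, adding or removing a support example changes the weight given to every other example.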
19.7 Weakly supervised learning


The term weakly supervised learning refers to scenarios where we do not have an exact label 
associated with every feature vector in the training set. 

One scenario is when we have a distribution over labels for each case, rather than a single label.
Fortunately, we can still do maximum likelihood training: we just have to minimize the cross entropy,

$$\mathcal{L}(\theta) = -\sum_{n} \sum_{y} p(y | x_n) \log q_\theta(y | x_n) \tag{19.56}$$

where $p(y | x_n)$ is the label distribution for case n, and $q_\theta(y | x_n)$ is the predicted distribution. Indeed,
it is often useful to artificially replace exact labels with a “soft” version, in which we replace the delta
function with a distribution that puts, say, 90% of its mass on the observed label, and spreads the
remaining mass uniformly over the other choices. This is called label smoothing, and is a useful
form of regularization (see e.g., [MKH19]).
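
The following sketch builds smoothed targets with 90% of the mass on the observed label and evaluates the cross entropy of Equation (19.56). The predicted probabilities are made-up numbers standing in for a model's output.

```python
import numpy as np

def smooth_labels(labels, num_classes, mass_on_label=0.9):
    """Put mass_on_label on the observed class, spread the rest uniformly."""
    off_value = (1.0 - mass_on_label) / (num_classes - 1)
    targets = np.full((len(labels), num_classes), off_value)
    targets[np.arange(len(labels)), labels] = mass_on_label
    return targets

def soft_cross_entropy(targets, predicted_probs):
    """Equation (19.56): -sum_n sum_y p(y|x_n) log q(y|x_n)."""
    return -np.sum(targets * np.log(predicted_probs))

labels = np.array([0, 2])
targets = smooth_labels(labels, num_classes=3)
predicted = np.array([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]])   # hypothetical model outputs
loss = soft_cross_entropy(targets, predicted)
```

Since every class receives nonzero target mass, the loss penalizes predictions that assign probability near zero to any class, which discourages overconfident outputs.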

Another scenario is when we have a set, or bag, of instances, $x_n = \{x_{n,1}, \ldots, x_{n,B}\}$, but we only
have a label for the entire bag, $y_n$, not for the members of the bag, $y_{nb}$. We often assume that if any
member of the bag is positive, the whole bag is labeled positive, so $y_n = \vee_{b=1}^{B} y_{nb}$, but we do not
know which member “caused” the positive outcome. However, if all the members are negative, the
entire bag is negative. This is known as multi-instance learning [DLLP97]. (For a recent example
of this in the context of COVID-19 risk score learning, see [MKS21].) Various algorithms can be used
to solve the MIL problem, depending on what assumptions we make about the correlation between
the labels in each bag, and the fraction of positive members we expect to see (see e.g., [KF05]).

Yet another scenario is known as distant supervision [Min+09], which is often used to train
information extraction systems. The idea is that we have some fact, such as “Married(B,M)”, that
we know to be true (since it is stored in a database). We use this to label every sentence (in our
unlabeled training corpus) in which the entities B and M are mentioned as being a positive example
of the “Married” relation. For example, the sentence “B and M invited 100 people to their wedding”
will be labeled positive. But this heuristic may include false positives, for example “B and M went
out to dinner” will also be labeled positive. Thus the resulting labels will be noisy. We discuss some
ways to handle label noise in Section 10.4.


19.8 Exercises 


Exercise 19.1 [Information gain equations] 


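Consider the following two objectives for evaluating the utility of querying a datapoint x in an active learning
setting:

\begin{align}
U(x) &\triangleq \mathbb{H}(p(\theta | \mathcal{D})) - \mathbb{E}_{p(y | x, \mathcal{D})} \left[ \mathbb{H}(p(\theta | \mathcal{D}, x, y)) \right] \tag{19.57} \\
U'(x) &\triangleq \mathbb{E}_{p(y | x, \mathcal{D})} \left[ D_{\mathrm{KL}}(p(\theta | \mathcal{D}, x, y) \,\|\, p(\theta | \mathcal{D})) \right] \tag{19.58}
\end{align}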


Prove that these are equal. 
20 Dimensionality Reduction


A common form of unsupervised learning is dimensionality reduction, in which we learn a mapping
from the high-dimensional visible space, $x \in \mathbb{R}^D$, to a low-dimensional latent space, $z \in \mathbb{R}^L$. This
mapping can either be a parametric model $z = f(x; \theta)$ which can be applied to any input, or it can
be a nonparametric mapping where we compute an embedding $z_n$ for each input $x_n$ in the data
set, but not for any other points. This latter approach is mostly used for data visualization, whereas
the former approach can also be used as a preprocessing step for other kinds of learning algorithms.
For example, we might first reduce the dimensionality by learning a mapping from x to z, and then
learn a simple linear classifier on this embedding, by mapping z to y.


20.1 Principal components analysis (PCA) 


The simplest and most widely used form of dimensionality reduction is principal components
analysis or PCA. The basic idea is to find a linear and orthogonal projection of the high dimensional
data $x \in \mathbb{R}^D$ to a low dimensional subspace $z \in \mathbb{R}^L$, such that the low dimensional representation is
a “good approximation” to the original data, in the following sense: if we project or encode x to get
$z = W^T x$, and then unproject or decode z to get $\hat{x} = W z$, then we want $\hat{x}$ to be close to x in $\ell_2$
distance. In particular, we can define the following reconstruction error or distortion:

$$\mathcal{L}(W) \triangleq \frac{1}{N} \sum_{n=1}^{N} \| x_n - \text{decode}(\text{encode}(x_n; W); W) \|_2^2 \tag{20.1}$$

where the encoding and decoding stages are both linear maps, as we explain below.

In Section 20.1.2, we show that we can minimize this objective by setting $W = U_L$, where $U_L$
contains the L eigenvectors with largest eigenvalues of the empirical covariance matrix

$$\hat{\Sigma} = \frac{1}{N} \sum_{n=1}^{N} (x_n - \bar{x})(x_n - \bar{x})^T = \frac{1}{N} X_c^T X_c \tag{20.2}$$

where $X_c$ is a centered version of the N × D design matrix. In Section 20.2.2, we show that this is
equivalent to maximizing the likelihood of a latent linear Gaussian model known as probabilistic
PCA.
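
The encode and decode steps and the reconstruction error of Equation (20.1) can be sketched directly with NumPy. The data here is synthetic, and the helper names (`fit_pca`, `encode`, `decode`) are hypothetical choices for this illustration.

```python
import numpy as np

def fit_pca(X, L):
    """Return the data mean, the top-L eigenvectors of the empirical covariance, and their eigenvalues."""
    mean = X.mean(axis=0)
    Xc = X - mean
    cov = Xc.T @ Xc / len(X)
    evals, evecs = np.linalg.eigh(cov)          # eigenvalues in ascending order
    order = np.argsort(evals)[::-1][:L]
    return mean, evecs[:, order], evals[order]

def encode(X, mean, W):
    """Project centered data onto the columns of W."""
    return (X - mean) @ W

def decode(Z, mean, W):
    """Map latent codes back to the input space."""
    return Z @ W.T + mean

def reconstruction_error(X, mean, W):
    """Average squared l2 distance between inputs and reconstructions."""
    X_hat = decode(encode(X, mean, W), mean, W)
    return np.mean(np.sum((X - X_hat) ** 2, axis=1))

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5)) @ rng.normal(size=(5, 5))   # synthetic correlated data
errors = [reconstruction_error(X, *fit_pca(X, L)[:2]) for L in range(1, 6)]
```

The error shrinks as L grows and vanishes (up to rounding) when L equals the input dimension, since W is then a full orthogonal basis.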


20.1.1 Examples 


Before giving the details, we start by showing some examples. 
Figure 20.1: An illustration of PCA where we project from 2d to 1d. Red circles are the original data points,
blue circles are the reconstructions. The red dot is the data mean. Generated by pcaDemo2d.ipynb.

Figure 20.2: An illustration of PCA applied to MNIST digits from class 9. (Axes: first principal component,
horizontal; second principal component, vertical.) Grid points are at the 5, 25, 50,
75, 95 % quantiles of the data distribution along each dimension. The circled points are the closest projected
images to the vertices of the grid. Adapted from Figure 14.23 of [HTF09]. Generated by pca_digits.ipynb.


Figure 20.1 shows a very simple example, where we project 2d data to a 1d line. This direction 
captures most of the variation in the data. 

In Figure 20.2, we show what happens when we project some MNIST images of the digit 9 down to 
2d. Although the inputs are high dimensional (specifically 28 x 28 = 784 dimensional), the number 
of “effective degrees of freedom” is much less, since the pixels are correlated, and many digits look 
similar. Therefore we can represent each image as a point in a low dimensional linear space. 

In general, it can be hard to interpret the latent dimensions to which the data is projected. However, 
by looking at several projected points along a given direction, and the examples from which they 
are derived, we see that the first principal component (horizontal direction) seems to capture the 
orientation of the digit, and the second component (vertical direction) seems to capture line thickness. 

In Figure 20.3, we show PCA applied to another image dataset, known as the Olivetti face dataset,
which is a set of 64 × 64 grayscale images. We project these to a 3d subspace. The resulting basis
vectors (columns of the projection matrix W) are shown as images in Figure 20.3b; these are
known as eigenfaces [Tur13], for reasons that will be explained in Section 20.1.2. We see that the
main modes of variation in the data are related to overall lighting, and then differences in the eyebrow
region of the face. If we use enough dimensions (but fewer than the 4096 we started with), we can
use the representation $z = W^T x$ as input to a nearest-neighbor classifier to perform face recognition;
this is faster and more reliable than working in pixel space [MWP98].

Figure 20.3: (a) Some randomly chosen 64 × 64 pixel images from the Olivetti face database. (b) The mean
and the first three PCA components represented as images. Generated by pcaImageDemo.ipynb.


20.1.2 Derivation of the algorithm 


Suppose we have an (unlabeled) dataset $\mathcal{D} = \{x_n : n = 1 : N\}$, where $x_n \in \mathbb{R}^D$. We can represent
this as an N × D data matrix X. We will assume $\bar{x} = \frac{1}{N} \sum_{n=1}^{N} x_n = 0$, which we can ensure by
centering the data.

We would like to approximate each $x_n$ by a low dimensional representation, $z_n \in \mathbb{R}^L$. We assume
that each $x_n$ can be “explained” in terms of a weighted combination of basis functions $w_1, \ldots, w_L$,
where each $w_k \in \mathbb{R}^D$, and where the weights are given by $z_n \in \mathbb{R}^L$, i.e., we assume $x_n \approx \sum_{k=1}^{L} z_{nk} w_k$.
The vector $z_n$ is the low dimensional representation of $x_n$, and is known as the latent vector, since
it consists of latent or “hidden” values that are not observed in the data. The collection of these
latent variables is called the latent factors.

We can measure the error produced by this approximation as follows:

$$\mathcal{L}(W, Z) = \frac{1}{N} \| X - Z W^T \|_F^2 = \frac{1}{N} \| X^T - W Z^T \|_F^2 = \frac{1}{N} \sum_{n=1}^{N} \| x_n - W z_n \|^2 \tag{20.3}$$

where the rows of Z contain the low dimensional versions of the rows of X. This is known as the
(average) reconstruction error, since we are approximating each $x_n$ by $\hat{x}_n = W z_n$.

We want to minimize this subject to the constraint that W is an orthogonal matrix. Below we
show that the optimal solution is obtained by setting $W = U_L$, where $U_L$ contains the L eigenvectors
with largest eigenvalues of the empirical covariance matrix.
20.1.2.1 Base case


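Let us start by estimating the best 1d solution, $w_1 \in \mathbb{R}^D$. We will find the remaining basis vectors
$w_2$, $w_3$, etc. later.
Let the coefficients for each of the data points associated with the first basis vector be denoted by
$\tilde{z}_1 = [z_{11}, \ldots, z_{N1}] \in \mathbb{R}^N$. The reconstruction error is given by

\begin{align}
\mathcal{L}(w_1, \tilde{z}_1) &= \frac{1}{N} \sum_{n=1}^{N} \| x_n - z_{n1} w_1 \|^2 = \frac{1}{N} \sum_{n=1}^{N} (x_n - z_{n1} w_1)^T (x_n - z_{n1} w_1) \tag{20.4} \\
&= \frac{1}{N} \sum_{n=1}^{N} \left[ x_n^T x_n - 2 z_{n1} w_1^T x_n + z_{n1}^2 w_1^T w_1 \right] \tag{20.5} \\
&= \frac{1}{N} \sum_{n=1}^{N} \left[ x_n^T x_n - 2 z_{n1} w_1^T x_n + z_{n1}^2 \right] \tag{20.6}
\end{align}

since $w_1^T w_1 = 1$ (by the orthonormality assumption). Taking derivatives wrt $z_{n1}$ and equating to
zero gives

$$\frac{\partial}{\partial z_{n1}} \mathcal{L}(w_1, \tilde{z}_1) = \frac{1}{N} \left[ -2 w_1^T x_n + 2 z_{n1} \right] = 0 \Rightarrow z_{n1} = w_1^T x_n \tag{20.7}$$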


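So the optimal embedding is obtained by orthogonally projecting the data onto $w_1$ (see Figure 20.1).
Plugging this back in gives the loss for the weights:

$$\mathcal{L}(w_1) = \mathcal{L}(w_1, \tilde{z}_1^*(w_1)) = \frac{1}{N} \sum_{n=1}^{N} \left[ x_n^T x_n - z_{n1}^2 \right] = \text{const} - \frac{1}{N} \sum_{n=1}^{N} z_{n1}^2 \tag{20.8}$$

To solve for $w_1$, note that

$$\mathcal{L}(w_1) = -\frac{1}{N} \sum_{n=1}^{N} z_{n1}^2 = -\frac{1}{N} \sum_{n=1}^{N} w_1^T x_n x_n^T w_1 = -w_1^T \hat{\Sigma} w_1 \tag{20.9}$$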
where $\hat{\Sigma}$ is the empirical covariance matrix (since we assumed the data is centered). We can trivially
optimize this by letting $\| w_1 \| \to \infty$, so we impose the constraint $\| w_1 \| = 1$ and instead optimize

$$\tilde{\mathcal{L}}(w_1) = w_1^T \hat{\Sigma} w_1 - \lambda_1 (w_1^T w_1 - 1) \tag{20.10}$$

where $\lambda_1$ is a Lagrange multiplier (see Section 8.5.1). Taking derivatives and equating to zero we
have

\begin{align}
\frac{\partial}{\partial w_1} \tilde{\mathcal{L}}(w_1) &= 2 \hat{\Sigma} w_1 - 2 \lambda_1 w_1 = 0 \tag{20.11} \\
\hat{\Sigma} w_1 &= \lambda_1 w_1 \tag{20.12}
\end{align}

Hence the optimal direction onto which we should project the data is an eigenvector of the covariance
matrix. Left multiplying by $w_1^T$ (and using $w_1^T w_1 = 1$) we find

$$w_1^T \hat{\Sigma} w_1 = \lambda_1 \tag{20.13}$$

Since we want to maximize this quantity (minimize the loss), we pick the eigenvector which corresponds
to the largest eigenvalue.
Figure 20.4: Illustration of the variance of the points projected onto different 1d vectors. $v_1$ is the first
principal component, which maximizes the variance of the projection. $v_2$ is the second principal component,
which is the direction orthogonal to $v_1$. Finally $v'$ is some other vector in between $v_1$ and $v_2$. Adapted from
Figure 8.7 of [Gér19]. Generated by pca_projected_variance.ipynb.


20.1.2.2 Optimal weight vector maximizes the variance of the projected data 


Before continuing, we make an interesting observation. Since the data has been centered, we have

$$\mathbb{E}[z_{n1}] = \mathbb{E}[x_n^T w_1] = \mathbb{E}[x_n]^T w_1 = 0 \tag{20.14}$$

Hence the variance of the projected data is given by

$$\mathbb{V}[\tilde{z}_1] = \mathbb{E}[\tilde{z}_1^2] - (\mathbb{E}[\tilde{z}_1])^2 = \frac{1}{N} \sum_{n=1}^{N} z_{n1}^2 - 0 = -\mathcal{L}(w_1) + \text{const} \tag{20.15}$$

From this, we see that minimizing the reconstruction error is equivalent to maximizing the variance
of the projected data:

$$\arg\min_{w_1} \mathcal{L}(w_1) = \arg\max_{w_1} \mathbb{V}[\tilde{z}_1(w_1)] \tag{20.16}$$

This is why it is often said that PCA finds the directions of maximal variance. (See Figure 20.4 for an
illustration.) However, the minimum error formulation is easier to understand and is more general.
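
The equivalence between minimum reconstruction error and maximum projected variance can be checked numerically. The sketch below uses synthetic centered 2d data and sweeps over unit vectors w; for each direction, the projected variance and the reconstruction error should sum to the same constant (the mean squared norm of the data).

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(500, 2)) @ np.array([[2.0, 0.0], [1.2, 0.5]])   # synthetic data
X = X - X.mean(axis=0)                       # center the data
total = np.mean(np.sum(X ** 2, axis=1))      # the "const" term

def projected_variance(w):
    """Variance of the 1d projection z = X w (data is centered)."""
    z = X @ w
    return np.mean(z ** 2)

def reconstruction_error(w):
    """Average squared distance between x_n and its projection onto w."""
    z = X @ w
    return np.mean(np.sum((X - np.outer(z, w)) ** 2, axis=1))

angles = np.linspace(0.0, np.pi, 181)
directions = np.stack([np.cos(angles), np.sin(angles)], axis=1)     # unit vectors
variances = np.array([projected_variance(w) for w in directions])
errors = np.array([reconstruction_error(w) for w in directions])
```

Since the two quantities sum to a constant, the direction with the largest projected variance is also the one with the smallest reconstruction error.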


20.1.2.3 Induction step 


Now let us find another direction $w_2$ to further minimize the reconstruction error, subject to
$w_1^T w_2 = 0$ and $w_2^T w_2 = 1$. The error is

$$\mathcal{L}(w_1, \tilde{z}_1, w_2, \tilde{z}_2) = \frac{1}{N} \sum_{n=1}^{N} \| x_n - z_{n1} w_1 - z_{n2} w_2 \|^2 \tag{20.17}$$
Figure 20.5: Effect of standardization on PCA applied to the height/weight dataset. (Red=female, blue=male.)
Left: PCA of raw data. Right: PCA of standardized data. Generated by pcaStandardization.ipynb.


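Optimizing wrt $w_1$ and $\tilde{z}_1$ gives the same solution as before. Exercise 20.3 asks you to show that
$\partial \mathcal{L} / \partial \tilde{z}_2 = 0$ yields $z_{n2} = w_2^T x_n$. Substituting in yields

$$\mathcal{L}(w_2) = \frac{1}{N} \sum_{n=1}^{N} \left[ x_n^T x_n - w_1^T x_n x_n^T w_1 - w_2^T x_n x_n^T w_2 \right] = \text{const} - w_2^T \hat{\Sigma} w_2 \tag{20.18}$$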


Dropping the constant term, plugging in the optimal $w_1$ and adding the constraints yields

$$\tilde{\mathcal{L}}(w_2) = -w_2^T \hat{\Sigma} w_2 + \lambda_2 (w_2^T w_2 - 1) + \lambda_{12} (w_1^T w_2 - 0) \tag{20.19}$$

Exercise 20.3 asks you to show that the solution is given by the eigenvector with the second largest
eigenvalue:

$$\hat{\Sigma} w_2 = \lambda_2 w_2 \tag{20.20}$$

The proof continues in this way to show that $W = U_L$.


20.1.3 Computational issues 


In this section, we discuss various practical issues related to using PCA. 


20.1.3.1 Covariance matrix vs correlation matrix 


We have been working with the eigendecomposition of the covariance matrix. However, it is better to 
use the correlation matrix instead. The reason is that otherwise PCA can be “misled” by directions in 
which the variance is high merely because of the measurement scale. Figure 20.5 shows an example 
of this. On the left, we see that the vertical axis uses a larger range than the horizontal axis. This 
results in a first principal component that looks somewhat “unnatural”. On the right, we show the 
results of PCA after standardizing the data (which is equivalent to using the correlation matrix 
instead of the covariance matrix); the results look much better. 
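
The effect of measurement scale can be reproduced with a small synthetic example. The heights and weights below are randomly generated (they are not the dataset of Figure 20.5), and the weight column is deliberately expressed in grams to exaggerate its variance.

```python
import numpy as np

rng = np.random.default_rng(0)
height = rng.normal(170.0, 10.0, size=300)               # hypothetical heights in cm
weight = 0.5 * height + rng.normal(0.0, 3.0, size=300)   # hypothetical weights in kg
X = np.stack([height, 1000.0 * weight], axis=1)          # weight converted to grams

def first_component(X):
    """Eigenvector of the empirical covariance with the largest eigenvalue."""
    Xc = X - X.mean(axis=0)
    evals, evecs = np.linalg.eigh(Xc.T @ Xc / len(X))
    return evecs[:, -1]

X_std = (X - X.mean(axis=0)) / X.std(axis=0)              # standardize each column
w_raw = first_component(X)
w_std = first_component(X_std)
cov_of_standardized = X_std.T @ X_std / len(X)
corr = np.corrcoef(X, rowvar=False)
```

On the raw data the first component is almost entirely aligned with the large-scale column, whereas after standardization both features contribute equally, and the covariance of the standardized data coincides with the correlation matrix.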
20.1.3.2 Dealing with high-dimensional data


We have presented PCA as the problem of finding the eigenvectors of the D × D covariance matrix
$X^T X$. If D > N, it is faster to work with the N × N Gram matrix $X X^T$. We now show how to do
this.

First, let U be an orthogonal matrix containing the eigenvectors of $X X^T$ with corresponding
eigenvalues in $\Lambda$. By definition we have $(X X^T) U = U \Lambda$. Pre-multiplying by $X^T$ gives

$$(X^T X)(X^T U) = (X^T U) \Lambda \tag{20.21}$$

from which we see that the eigenvectors of $X^T X$ are $V = X^T U$, with eigenvalues given by $\Lambda$ as
before. However, these eigenvectors are not normalized, since $\| v_j \|^2 = u_j^T X X^T u_j = \lambda_j u_j^T u_j = \lambda_j$.
The normalized eigenvectors are given by

$$V = X^T U \Lambda^{-1/2} \tag{20.22}$$

This provides an alternative way to compute the PCA basis. It also allows us to use the kernel trick,
as we discuss in Section 20.4.6.
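
The Gram matrix route can be checked against the direct eigendecomposition on a small synthetic problem with D > N. The sizes below are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, L = 20, 500, 3
X = rng.normal(size=(N, D))
X = X - X.mean(axis=0)                      # center the data

# Eigendecomposition of the small N x N Gram matrix X X^T.
gram_evals, U = np.linalg.eigh(X @ X.T)
top = np.argsort(gram_evals)[::-1][:L]
lam, U = gram_evals[top], U[:, top]

# Equation (20.22): map to unit-norm eigenvectors of X^T X.
V = X.T @ U / np.sqrt(lam)

# Reference: eigenvectors of the large D x D matrix X^T X.
big_evals, big_evecs = np.linalg.eigh(X.T @ X)
V_ref = big_evecs[:, np.argsort(big_evals)[::-1][:L]]
```

The columns of `V` agree with `V_ref` up to sign, but are obtained from a 20 × 20 eigenproblem instead of a 500 × 500 one.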


20.1.3.3 Computing PCA using SVD 


In this section, we show the equivalence between PCA as computed using eigenvector methods
(Section 20.1) and the truncated SVD.¹

Let $U_\Sigma \Lambda_\Sigma U_\Sigma^T$ be the top L eigendecomposition of the covariance matrix $\Sigma \propto X^T X$ (we assume X
is centered). Recall from Section 20.1.2 that the optimal estimate of the projection weights W is
given by the top L eigenvectors, so $W = U_\Sigma$.

Now let $U_X S_X V_X^T \approx X$ be the L-truncated SVD approximation to the data matrix X. From
Equation (7.184), we know that the right singular vectors of X are the eigenvectors of $X^T X$, so
$V_X = U_\Sigma = W$. (In addition, the eigenvalues of the covariance matrix are related to the singular
values of the data matrix via $\lambda_k = s_k^2 / N$.)

Now suppose we are interested in the projected points (also called the principal components or PC
scores), rather than the projection matrix. We have

$$Z = X W = U_X S_X V_X^T V_X = U_X S_X \tag{20.23}$$

Finally, if we want to approximately reconstruct the data, we have

$$\hat{X} = Z W^T = U_X S_X V_X^T \tag{20.24}$$

This is precisely the same as a truncated SVD approximation (Section 7.5.5).

Thus we see that we can perform PCA either using an eigendecomposition of $\Sigma$ or an SVD
decomposition of X. The latter is often preferable, for computational reasons. For very high
dimensional problems, we can use a randomized SVD algorithm, see e.g., [HMT11; SKT14; DM16].
For example, the randomized solver used by sklearn takes $O(N L^2) + O(L^3)$ time for N examples
and L principal components, whereas exact SVD takes $O(N D^2) + O(D^3)$ time.
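
The equivalence of the two routes is easy to verify numerically. The sketch below uses synthetic centered data and NumPy's exact SVD (not a randomized solver), and compares the resulting eigenvalues, projection matrix, and PC scores.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, L = 100, 6, 2
X = rng.normal(size=(N, D)) @ rng.normal(size=(D, D))   # synthetic correlated data
X = X - X.mean(axis=0)

# Route 1: eigendecomposition of the covariance matrix.
evals, evecs = np.linalg.eigh(X.T @ X / N)
order = np.argsort(evals)[::-1][:L]
lam, W_eig = evals[order], evecs[:, order]

# Route 2: truncated SVD of the data matrix.
U, s, Vt = np.linalg.svd(X, full_matrices=False)
W_svd = Vt[:L].T
Z = U[:, :L] * s[:L]          # Equation (20.23): PC scores
X_hat = Z @ W_svd.T           # Equation (20.24): rank-L reconstruction
```

The two projection matrices can differ by a sign flip per column, which is why comparisons should use absolute inner products rather than elementwise equality.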


1. A more detailed explanation can be found at https://bit.ly/2I5660K. 
Figure 20.6: Reconstruction error on MNIST vs number of latent dimensions used by PCA. (a) Training set.
(b) Test set. Generated by pcaOverfitDemo.ipynb.


20.1.4 Choosing the number of latent dimensions 


In this section, we discuss how to choose the number of latent dimensions L for PCA. 


20.1.4.1 Reconstruction error 


Let us define the reconstruction error on some dataset $\mathcal{D}$ incurred by the model when using L
dimensions:

$$\mathcal{L}_L = \frac{1}{|\mathcal{D}|} \sum_{n \in \mathcal{D}} \| x_n - \hat{x}_n \|^2 \tag{20.25}$$

where the reconstruction is given by $\hat{x}_n = W z_n + \mu$, where $z_n = W^T (x_n - \mu)$ and $\mu$ is the empirical
mean, and W is estimated as above. Figure 20.6(a) plots $\mathcal{L}_L$ vs L on the MNIST training data. We
see that it drops off quite quickly, indicating that we can capture most of the empirical correlation of
the pixels with a small number of factors.

Of course, if we use L = rank(X), we get zero reconstruction error on the training set. To avoid 
overfitting, it is natural to plot reconstruction error on the test set. This is shown in Figure 20.6(b). 
Here we see that the error continues to go down even as the model becomes more complex! Thus 
we do not get the usual U-shaped curve that we typically expect to see in supervised learning. 
The problem is that PCA is not a proper generative model of the data: If you give it more latent 
dimensions, it will be able to approximate the test data more accurately. (A similar problem arises if 
we plot reconstruction error on the test set using K-means clustering, as discussed in Section 21.3.7.) 
We discuss some solutions to this below. 


20.1.4.2 Scree plots 


Figure 20.7: (a) Scree plot for training set, corresponding to Figure 20.6(a). (b) Fraction of variance explained.
Generated by pcaOverfitDemo.ipynb.

A common alternative to plotting reconstruction error vs L is to use something called a scree
plot, which is a plot of the eigenvalues $\lambda_j$ vs j in order of decreasing magnitude. One can show
(Exercise 20.4) that

$$\mathcal{L}_L = \sum_{j=L+1}^{D} \lambda_j \tag{20.26}$$

Thus as the number of dimensions increases, the eigenvalues get smaller, and so does the reconstruction
error, as shown in Figure 20.7a.² A related quantity is the fraction of variance explained, defined
as

$$F_L = \frac{\sum_{j=1}^{L} \lambda_j}{\sum_{j'=1}^{L^{\max}} \lambda_{j'}} \tag{20.27}$$

This captures the same information as the scree plot, but goes up with L (see Figure 20.7b).
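
Equations (20.26) and (20.27) can be checked on synthetic data. In the sketch below the per-dimension scales are arbitrary assumptions chosen to give a decaying spectrum; the reconstruction error on the training set is compared with the sum of the discarded eigenvalues.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 300, 8
scales = np.array([5, 3, 2, 1, 0.5, 0.3, 0.2, 0.1])     # hypothetical feature scales
X = rng.normal(size=(N, D)) * scales
X = X - X.mean(axis=0)

evals, evecs = np.linalg.eigh(X.T @ X / N)
evals, evecs = evals[::-1], evecs[:, ::-1]               # decreasing order (scree plot values)

def reconstruction_error(L):
    """Average squared error when keeping the top L components."""
    W = evecs[:, :L]
    X_hat = X @ W @ W.T
    return np.mean(np.sum((X - X_hat) ** 2, axis=1))

errors = np.array([reconstruction_error(L) for L in range(1, D + 1)])
tail_sums = np.array([evals[L:].sum() for L in range(1, D + 1)])    # Equation (20.26)
fraction_explained = np.cumsum(evals) / evals.sum()                 # Equation (20.27)
```

Plotting `evals` against the component index gives the scree plot, and plotting `fraction_explained` gives the complementary increasing curve.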


20.1.4.3 Profile likelihood 


Although there is no U-shape in the reconstruction error plot, there is sometimes a “knee” or “elbow” 
in the curve, where the error suddenly changes from relatively large errors to relatively small. The 
idea is that for L < L*, where L* is the “true” latent dimensionality (or number of clusters), the rate 
of decrease in the error function will be high, whereas for L > L*, the gains will be smaller, since the 
model is already sufficiently complex to capture the true distribution. 

2. The reason for the term “scree plot” is that “the plot looks like the side of a mountain, and ’scree’ refers to the
debris fallen from a mountain and lying at its base”. (Quotation from Kenneth Janda, https://bit.ly/2kqGiyWw.)

One way to automate the detection of this change in the gradient of the curve is to compute
the profile likelihood, as proposed in [ZG06]. The idea is this. Let $\lambda_L$ be some measure of the
error incurred by a model of size L, such that $\lambda_1 \geq \lambda_2 \geq \cdots \geq \lambda_{L^{\max}}$. In PCA, these are the
eigenvalues, but the method can also be applied to the reconstruction error from K-means clustering
(see Section 21.3.7). Now consider partitioning these values into two groups, depending on whether
$k \leq L$ or $k > L$, where L is some threshold which we will determine. To measure the quality of L,
we will use a simple change-point model, where $\lambda_k \sim \mathcal{N}(\mu_1, \sigma^2)$ if $k \leq L$, and $\lambda_k \sim \mathcal{N}(\mu_2, \sigma^2)$ if
$k > L$. (It is important that $\sigma^2$ be the same in both models, to prevent overfitting in the case where
one regime has less data than the other.) Within each of the two regimes, we assume the $\lambda_k$ are
iid, which is obviously incorrect, but is adequate for our present purposes. We can fit this model for
each $L = 1 : L^{\max}$ by partitioning the data and computing the MLEs, using a pooled estimate of the
variance:

$$\mu_1(L) = \frac{\sum_{k \leq L} \lambda_k}{L} \tag{20.28}$$

$$\mu_2(L) = \frac{\sum_{k > L} \lambda_k}{L^{\max} - L} \tag{20.29}$$

$$\sigma^2(L) = \frac{\sum_{k \leq L} (\lambda_k - \mu_1(L))^2 + \sum_{k > L} (\lambda_k - \mu_2(L))^2}{L^{\max}} \tag{20.30}$$

We can then evaluate the profile log likelihood

$$\ell(L) = \sum_{k=1}^{L} \log \mathcal{N}(\lambda_k | \mu_1(L), \sigma^2(L)) + \sum_{k=L+1}^{L^{\max}} \log \mathcal{N}(\lambda_k | \mu_2(L), \sigma^2(L)) \tag{20.31}$$

This is illustrated in Figure 20.8. We see that the peak $L^* = \arg\max_L \ell(L)$ is well determined.

Figure 20.8: Profile likelihood corresponding to PCA model in Figure 20.6(a). Generated by pcaOverfitDemo.ipynb.
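The change-point fit is short enough to sketch directly. The following is a minimal NumPy illustration, not the code behind Figure 20.8; the function name, the variance floor, and the toy spectrum are assumptions made for this example.

```python
import numpy as np

def profile_log_likelihood(lam):
    """Profile log likelihood l(L) for L = 1..Lmax, following the change-point model.

    lam: error values (e.g., PCA eigenvalues) sorted in decreasing order.
    Returns an array whose entry L-1 holds l(L).
    """
    lam = np.asarray(lam, dtype=float)
    Lmax = lam.size
    ll = np.empty(Lmax)
    for L in range(1, Lmax + 1):
        left, right = lam[:L], lam[L:]
        mu1 = left.mean()
        # For L = Lmax the second regime is empty and contributes nothing.
        mu2 = right.mean() if right.size > 0 else 0.0
        ss = np.sum((left - mu1) ** 2) + np.sum((right - mu2) ** 2)
        # Pooled variance; the floor is an assumption to avoid log(0).
        var = max(ss / Lmax, 1e-12)
        ll[L - 1] = -0.5 * Lmax * np.log(2 * np.pi * var) - 0.5 * ss / var
    return ll

# Made-up spectrum with an obvious knee after the third value.
lam = np.array([10.0, 9.5, 9.0, 1.0, 0.9, 0.8, 0.7])
ll = profile_log_likelihood(lam)
L_star = int(np.argmax(ll)) + 1
```

Because the pooled variance is itself the MLE, maximizing the profile log likelihood amounts to picking the split with the smallest pooled within-group variance.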


20.2 Factor analysis * 


PCA is a simple method for computing a linear low-dimensional representation of data. In this section,
we present a generalization of PCA known as factor analysis. This is based on a probabilistic
model, which means we can treat it as a building block for more complex models, such as the mixture
of FA models in Section 20.2.6, or the nonlinear FA model in Section 20.3.5. We can recover PCA as
a special limiting case of FA, as we discuss in Section 20.2.2.


20.2.1 Generative model


Factor analysis corresponds to the following linear-Gaussian latent variable generative model:

$$p(z) = \mathcal{N}(z | \mu_0, \Sigma_0) \tag{20.32}$$
$$p(x | z, \theta) = \mathcal{N}(x | W z + \mu, \Psi) \tag{20.33}$$

where W is a D × L matrix, known as the factor loading matrix, and $\Psi$ is a diagonal D × D
covariance matrix.

FA can be thought of as a low-rank version of a Gaussian distribution. To see this, note that the
induced marginal distribution $p(x | \theta)$ is a Gaussian (see Equation (3.38) for the derivation):

$$p(x | \theta) = \int \mathcal{N}(x | W z + \mu, \Psi) \, \mathcal{N}(z | \mu_0, \Sigma_0) \, dz \tag{20.34}$$
$$= \mathcal{N}(x | W \mu_0 + \mu, \Psi + W \Sigma_0 W^T) \tag{20.35}$$

Hence $E[x] = W \mu_0 + \mu$ and $\mathrm{Cov}[x] = W \mathrm{Cov}[z] W^T + \Psi = W \Sigma_0 W^T + \Psi$. From this, we see that
we can set $\mu_0 = 0$ without loss of generality, since we can always absorb $W \mu_0$ into $\mu$. Similarly, we
can set $\Sigma_0 = I$ without loss of generality, since we can always absorb a correlated prior by using a
new weight matrix, $\tilde{W} = W \Sigma_0^{1/2}$, so that $\tilde{W} \tilde{W}^T = W \Sigma_0 W^T$. After these simplifications we have

$$p(z) = \mathcal{N}(z | 0, I) \tag{20.36}$$
$$p(x | z) = \mathcal{N}(x | W z + \mu, \Psi) \tag{20.37}$$
$$p(x) = \mathcal{N}(x | \mu, W W^T + \Psi) \tag{20.38}$$


For example, suppose L = 1, D = 2 and $\Psi = \sigma^2 I$. We illustrate the generative process in
this case in Figure 20.9. We can think of this as taking an isotropic Gaussian “spray can”, representing
the likelihood $p(x | z)$, and “sliding it along” the 1d line defined by $w z + \mu$ as we vary the 1d latent
variable z. This induces an elongated (and hence correlated) Gaussian in 2d. That is, the induced
distribution has the form $p(x) = \mathcal{N}(x | \mu, w w^T + \sigma^2 I)$.

In general, FA approximates the covariance matrix of the visible vector using a low-rank decompo- 
sition: 


$$C = \mathrm{Cov}[x] = W W^T + \Psi \tag{20.39}$$


This only uses O(LD) parameters, which allows a flexible compromise between a full covariance
Gaussian, with $O(D^2)$ parameters, and a diagonal covariance, with O(D) parameters.

From Equation (20.39), we see that we should restrict $\Psi$ to be diagonal, otherwise we could set
W = 0, thus ignoring the latent factors, while still being able to model any covariance. The marginal
variance of each visible variable is given by $V[x_d] = \sum_{k=1}^{L} w_{dk}^2 + \psi_d$, where the first term is the
variance due to the common factors, and the second term, $\psi_d$, is called the uniqueness, and is the
variance term that is specific to that dimension.

We can estimate the parameters of an FA model using EM (see Section 20.2.3). Once we have
fit the model, we can compute probabilistic latent embeddings using $p(z | x)$. Using Bayes rule for
Gaussians we have

$$p(z | x) = \mathcal{N}(z | W^T C^{-1} (x - \mu), I - W^T C^{-1} W) \tag{20.40}$$

where C is defined in Equation (20.39).

Figure 20.9: Illustration of the FA generative process, where we have L = 1 latent dimension generating
D = 2 observed dimensions; we assume $\Psi = \sigma^2 I$. The latent factor has value $z \in \mathbb{R}$, sampled from
p(z); this gets mapped to a 2d offset $\delta = z w$, where $w \in \mathbb{R}^2$, which gets added to $\mu$ to define a Gaussian
$p(x | z) = \mathcal{N}(x | \mu + \delta, \sigma^2 I)$. By integrating over z, we “slide” this circular Gaussian “spray can” along the
principal component axis w, which induces elliptical Gaussian contours in x space centered on $\mu$. Adapted
from Figure 12.9 of [Bis06].
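As a sanity check on the generative model and the posterior formula, the sketch below samples from an FA model with made-up parameters, compares the empirical covariance with $W W^T + \Psi$, and computes the posterior over z for one data point. The sizes, the seed, and the helper name `fa_posterior` are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
D, L, N = 5, 2, 200_000
W = rng.normal(size=(D, L))            # factor loading matrix (made up)
mu = rng.normal(size=D)
psi = rng.uniform(0.1, 0.5, size=D)    # diagonal of Psi

# Ancestral sampling: z ~ N(0, I), then x = W z + mu + noise.
z = rng.normal(size=(N, L))
X = z @ W.T + mu + rng.normal(size=(N, D)) * np.sqrt(psi)

C = W @ W.T + np.diag(psi)             # model covariance
C_emp = np.cov(X, rowvar=False)        # empirical covariance of the samples

def fa_posterior(x, W, mu, psi):
    """Mean and covariance of p(z | x) for the FA model with z ~ N(0, I)."""
    C = W @ W.T + np.diag(psi)
    A = np.linalg.solve(C, W)          # C^{-1} W, shape (D, L)
    mean = A.T @ (x - mu)              # W^T C^{-1} (x - mu)
    cov = np.eye(W.shape[1]) - W.T @ A
    return mean, cov

post_mean, post_cov = fa_posterior(X[0], W, mu, psi)
```

The posterior covariance does not depend on x, so it can be computed once and reused for every data point.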


20.2.2 Probabilistic PCA 


In this section, we consider a special case of the factor analysis model in which W has orthonormal
columns, and $\Psi = \sigma^2 I$. This model is called probabilistic principal components analysis
(PPCA) [TB99], or sensible PCA [Row97]. The marginal distribution on the visible variables has
the form

$$p(x | \theta) = \int \mathcal{N}(x | \mu + W z, \sigma^2 I) \, \mathcal{N}(z | 0, I) \, dz = \mathcal{N}(x | \mu, C) \tag{20.41}$$

where

$$C = W W^T + \sigma^2 I \tag{20.42}$$


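The log likelihood for PPCA is given by

$$\log p(X | \mu, W, \sigma^2) = -\frac{ND}{2} \log(2\pi) - \frac{N}{2} \log |C| - \frac{1}{2} \sum_{n=1}^{N} (x_n - \mu)^T C^{-1} (x_n - \mu) \tag{20.43}$$

The MLE for $\mu$ is $\bar{x}$. Plugging in gives

$$\log p(X | \mu, W, \sigma^2) = -\frac{N}{2} \left[ D \log(2\pi) + \log |C| + \mathrm{tr}(C^{-1} S) \right] \tag{20.44}$$

where $S = \frac{1}{N} \sum_{n=1}^{N} (x_n - \bar{x})(x_n - \bar{x})^T$ is the empirical covariance matrix.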


[TB99; Row97] show that the maximum of this objective must satisfy

$$W = U_L (L_L - \sigma^2 I)^{1/2} R \tag{20.45}$$

where $U_L$ is a D × L matrix whose columns are given by the L eigenvectors of S with largest
eigenvalues, $L_L$ is the L × L diagonal matrix of eigenvalues, and R is an arbitrary L × L orthogonal
matrix, which (WLOG) we can take to be R = I. In the noise-free limit, where $\sigma^2 = 0$, we see that
$W_{\mathrm{mle}} = U_L L_L^{1/2}$, which is proportional to the PCA solution.
The MLE for the observation variance is

$$\sigma^2 = \frac{1}{D - L} \sum_{i=L+1}^{D} \lambda_i \tag{20.46}$$

which is the average distortion associated with the discarded dimensions. If L = D, then the
estimated noise is 0, since the model collapses to z = x.

To compute the likelihood $p(X | \mu, W, \sigma^2)$, we need to evaluate $C^{-1}$ and $\log |C|$, where C is a
D × D matrix. To do this efficiently, we can use the matrix inversion lemma to write

$$C^{-1} = \sigma^{-2} \left[ I - W M^{-1} W^T \right] \tag{20.47}$$

where the L × L dimensional matrix M is given by

$$M = W^T W + \sigma^2 I \tag{20.48}$$


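When we plug in the MLE for W from Equation (20.45) (using R = I), and use $U_L^T U_L = I$, we find

$$M = (L_L - \sigma^2 I) + \sigma^2 I = L_L \tag{20.49}$$

and hence

$$C^{-1} = \sigma^{-2} \left[ I - U_L (L_L - \sigma^2 I) \Lambda_L^{-1} U_L^T \right] \tag{20.50}$$
$$\log |C| = (D - L) \log \sigma^2 + \sum_{j=1}^{L} \log \lambda_j \tag{20.51}$$

where $\Lambda_L = L_L = \mathrm{diag}(\lambda_1, \ldots, \lambda_L)$. Thus we can avoid all matrix inversions (since $\Lambda_L^{-1} = \mathrm{diag}(1/\lambda_j)$).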
To use PPCA as an alternative to PCA, we need to compute the posterior mean $E[z | x]$, which is
the equivalent of the encoder model. Using Bayes rule for Gaussians we have

$$p(z | x) = \mathcal{N}(z | M^{-1} W^T (x - \mu), \sigma^2 M^{-1}) \tag{20.52}$$

where M is defined in Equation (20.48). In the $\sigma^2 = 0$ limit, the posterior mean using the MLE
parameters becomes

$$E[z | x] = (W^T W)^{-1} W^T (x - \bar{x}) \tag{20.53}$$

which is the orthogonal projection of the data into the latent space, as in standard PCA.
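The closed-form PPCA estimates can be computed from an eigendecomposition of S. The sketch below is one hedged way to do so in NumPy, taking R = I and assuming L < D; the function names and the synthetic data are illustrative assumptions, not the book's notebook code.

```python
import numpy as np

def ppca_mle(X, L):
    """Closed-form PPCA MLE with R = I. Assumes L < D."""
    N, D = X.shape
    mu = X.mean(axis=0)
    S = (X - mu).T @ (X - mu) / N
    evals, evecs = np.linalg.eigh(S)              # ascending order
    order = np.argsort(evals)[::-1]
    evals, evecs = evals[order], evecs[:, order]
    sigma2 = evals[L:].mean()                     # average of discarded eigenvalues
    W = evecs[:, :L] * np.sqrt(evals[:L] - sigma2)  # U_L (L_L - sigma^2 I)^{1/2}
    return mu, W, sigma2, evals

def ppca_log_lik(X, mu, W, sigma2):
    """Log likelihood -N/2 [D log(2 pi) + log|C| + tr(C^{-1} S)] for a given mu."""
    N, D = X.shape
    C = W @ W.T + sigma2 * np.eye(D)
    S = (X - mu).T @ (X - mu) / N
    _, logdet = np.linalg.slogdet(C)
    return -0.5 * N * (D * np.log(2 * np.pi) + logdet + np.trace(np.linalg.solve(C, S)))

def ppca_encode(X, mu, W, sigma2):
    """Posterior means E[z | x] = M^{-1} W^T (x - mu), one row per data point."""
    M = W.T @ W + sigma2 * np.eye(W.shape[1])
    return np.linalg.solve(M, W.T @ (X - mu).T).T

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 6)) @ rng.normal(size=(6, 6))   # synthetic correlated data
mu, W, sigma2, evals = ppca_mle(X, L=2)
Z = ppca_encode(X, mu, W, sigma2)
```

Only the small L × L matrix M is ever inverted in the encoder, which is the point of the matrix inversion lemma step.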


20.2.3 EM algorithm for FA/PPCA 


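In this section, we describe one method for computing the MLE for the FA model using the EM
algorithm, based on [RT82; GH96].


20.2.3.1 EM for FA


In the E step, we compute the posterior embeddings

$$p(z_i | x_i, \theta) = \mathcal{N}(z_i | m_i, \Sigma_i) \tag{20.54}$$
$$\Sigma_i \triangleq (I_L + W^T \Psi^{-1} W)^{-1} \tag{20.55}$$
$$m_i \triangleq \Sigma_i (W^T \Psi^{-1} (x_i - \mu)) \tag{20.56}$$

In the M step, it is easiest to estimate $\mu$ and W at the same time, by defining $\tilde{W} = (W, \mu)$,
$\tilde{z} = (z, 1)$. Also, define

$$b_i \triangleq E[\tilde{z} | x_i] = [m_i; 1] \tag{20.57}$$
$$C_i \triangleq E[\tilde{z} \tilde{z}^T | x_i] = \begin{pmatrix} E[z z^T | x_i] & E[z | x_i] \\ E[z | x_i]^T & 1 \end{pmatrix} \tag{20.58}$$

Then the M step is as follows:

$$\hat{\tilde{W}} = \left[ \sum_i x_i b_i^T \right] \left[ \sum_i C_i \right]^{-1} \tag{20.59}$$
$$\hat{\Psi} = \frac{1}{N} \mathrm{diag} \left\{ \sum_i (x_i - \hat{\tilde{W}} b_i) x_i^T \right\} \tag{20.60}$$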


Note that these updates are for “vanilla” EM. A much faster version of this algorithm, based on 
ECM, is described in [ZY08]. 
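One iteration of the vanilla EM updates can be written compactly with the augmented quantities $b_i$ and $C_i$. The sketch below is an illustrative NumPy version under the assumption that $E[z z^T | x_i] = \Sigma_i + m_i m_i^T$; the function names, initialization, and synthetic data are assumptions, and no attempt is made at the faster ECM variant.

```python
import numpy as np

def fa_log_lik(X, W, mu, psi):
    """Log likelihood of data under N(x | mu, W W^T + diag(psi))."""
    N, D = X.shape
    C = W @ W.T + np.diag(psi)
    Xc = X - mu
    _, logdet = np.linalg.slogdet(C)
    quad = np.sum(Xc * np.linalg.solve(C, Xc.T).T)
    return -0.5 * (N * D * np.log(2 * np.pi) + N * logdet + quad)

def fa_em_step(X, W, mu, psi):
    """One vanilla EM iteration for FA; psi holds the diagonal of Psi."""
    N, D = X.shape
    L = W.shape[1]
    # E step: the posterior covariance is shared by all data points.
    WtPinv = W.T / psi                                # W^T Psi^{-1}, shape (L, D)
    Sigma = np.linalg.inv(np.eye(L) + WtPinv @ W)
    M = (X - mu) @ WtPinv.T @ Sigma                   # rows are the posterior means m_i
    # Augmented statistics: b_i = [m_i; 1], and the sum of the C_i.
    B = np.hstack([M, np.ones((N, 1))])
    C_sum = B.T @ B
    C_sum[:L, :L] += N * Sigma                        # adds the posterior covariances
    # M step: W_tilde = (sum_i x_i b_i^T)(sum_i C_i)^{-1}; C_sum is symmetric.
    W_tilde = np.linalg.solve(C_sum, B.T @ X).T       # shape (D, L + 1)
    psi_new = np.mean((X - B @ W_tilde.T) * X, axis=0)
    return W_tilde[:, :L], W_tilde[:, L], psi_new

rng = np.random.default_rng(0)
W_true = rng.normal(size=(6, 2))
X = rng.normal(size=(400, 2)) @ W_true.T + 0.5 * rng.normal(size=(400, 6)) + 1.0

W, mu, psi = rng.normal(size=(6, 2)), X.mean(axis=0), np.ones(6)
log_liks = []
for _ in range(30):
    log_liks.append(fa_log_lik(X, W, mu, psi))
    W, mu, psi = fa_em_step(X, W, mu, psi)
```

Tracking the log likelihood across iterations is a cheap correctness check, since EM should never decrease it.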


20.2.3.2 EM for (P)PCA 


We can also use EM to fit the PPCA model, which provides a useful alternative to eigenvector
methods. This relies on the probabilistic formulation of PCA. However the algorithm continues to
work in the zero noise limit, $\sigma^2 = 0$, as shown by [Row97].

In particular, let $\tilde{Z} = Z^T$ be an L × N matrix storing the posterior means (low-dimensional
representations) along its columns. Similarly, let $\tilde{X} = X^T$ store the original data along its columns.
From Equation (20.52), when $\sigma^2 = 0$, we have

$$\tilde{Z} = (W^T W)^{-1} W^T \tilde{X} \tag{20.61}$$

This constitutes the E step. Notice that this is just an orthogonal projection of the data.
From Equation (20.59), the M step is given by

$$\hat{W} = \left[ \sum_i x_i E[z_i]^T \right] \left[ \sum_i E[z_i] E[z_i]^T \right]^{-1} \tag{20.62}$$

where we exploited the fact that $\Sigma = \mathrm{Cov}[z_i | x_i, \theta] = 0 I$ when $\sigma^2 = 0$.

It is worth comparing this expression to the MLE for multi-output linear regression (Equation 11.2),
which has the form $W = (\sum_i y_i x_i^T)(\sum_i x_i x_i^T)^{-1}$. Thus we see that the M step is like linear regression
where we replace the observed inputs by the expected values of the latent variables.

In summary, here is the entire algorithm:

$$\tilde{Z} = (W^T W)^{-1} W^T \tilde{X} \quad \text{(E step)} \tag{20.63}$$
$$W = \tilde{X} \tilde{Z}^T (\tilde{Z} \tilde{Z}^T)^{-1} \quad \text{(M step)} \tag{20.64}$$


[TB99] showed that the only stable fixed point of the EM algorithm is the globally optimal solution. 
That is, the EM algorithm converges to a solution where W spans the same linear subspace as that 
defined by the first L eigenvectors. However, if we want W to be orthogonal, and to contain the 
eigenvectors in descending order of eigenvalue, we have to orthogonalize the resulting matrix (which 
can be done quite cheaply). Alternatively, we can modify EM to give the principal basis directly 
[AO03]. 

This algorithm has a simple physical analogy in the case D = 2 and L = 1 [Row97]. Consider
some points in $\mathbb{R}^2$ attached by springs to a rigid rod, whose orientation is defined by a vector w. Let
$z_i$ be the location where the i’th spring attaches to the rod. In the E step, we hold the rod fixed, and
let the attachment points slide around so as to minimize the spring energy (which is proportional to 
the sum of squared residuals). In the M step, we hold the attachment points fixed and let the rod 
rotate so as to minimize the spring energy. See Figure 20.10 for an illustration. 
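The two-line algorithm translates almost verbatim into code. The following sketch alternates the E and M steps on centered data and then orthogonalizes W with a QR decomposition; the function name, iteration count, and synthetic data are assumptions for illustration.

```python
import numpy as np

def em_pca(X, L, n_iters=100, seed=0):
    """EM for PCA in the zero-noise limit. Returns an orthonormal basis (D x L)."""
    rng = np.random.default_rng(seed)
    Xt = (X - X.mean(axis=0)).T                      # D x N, centered, data in columns
    W = rng.normal(size=(Xt.shape[0], L))            # random initial guess
    for _ in range(n_iters):
        Z = np.linalg.solve(W.T @ W, W.T @ Xt)       # E step: project onto span(W)
        W = Xt @ Z.T @ np.linalg.inv(Z @ Z.T)        # M step: regress data on Z
    Q, _ = np.linalg.qr(W)                           # orthogonalize the converged W
    return Q

rng = np.random.default_rng(1)
latent = rng.normal(size=(300, 2)) * np.array([5.0, 3.0])
X = latent @ rng.normal(size=(2, 6)) + 0.1 * rng.normal(size=(300, 6))
Q = em_pca(X, L=2)

# Reference answer from an eigendecomposition of the empirical covariance.
Xc = X - X.mean(axis=0)
evals, evecs = np.linalg.eigh(Xc.T @ Xc / len(X))
U = evecs[:, -2:]                                    # top-2 eigenvectors
```

The columns of Q need not equal the eigenvectors individually; only the subspace they span is guaranteed to match, which is why the comparison is best made between the projectors $Q Q^T$ and $U U^T$.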


20.2.3.3 Advantages 


EM for PCA has the following advantages over eigenvector methods: 


• EM can be faster. In particular, assuming N, D ≫ L, the dominant cost of EM is the projection
operation in the E step, so the overall time is O(TLND), where T is the number of iterations.
[Row97] showed experimentally that the number of iterations is usually very small (the mean
was 3.6), regardless of N or D. (This result depends on the ratio of eigenvalues of the empirical
covariance matrix.) This is much faster than the $O(\min(ND^2, DN^2))$ time required by straightforward
eigenvector methods, although more sophisticated eigenvector methods, such as the Lanczos
algorithm, have running times comparable to EM.

• EM can be implemented in an online fashion, i.e., we can update our estimate of W as the data
streams in.

• EM can handle missing data in a simple way (see e.g., [IR10; DJ15]).

• EM can be extended to handle mixtures of PPCA/FA models (see Section 20.2.6).

• EM can be modified to variational EM or to variational Bayes EM to fit more complex models
(see e.g., Section 20.2.7).


20.2.4 Unidentifiability of the parameters


The parameters of a FA model are unidentifiable. To see this, consider a model with weights W and
observation covariance $\Psi$. We have

$$\mathrm{Cov}[x] = W E[z z^T] W^T + E[\epsilon \epsilon^T] = W W^T + \Psi \tag{20.65}$$

Figure 20.10: Illustration of EM for PCA when D = 2 and L = 1. Green stars are the original data points,
black circles are their reconstructions. The weight vector w is represented by the blue line. (a) We start with a
random initial guess of w. The E step is represented by the orthogonal projections. (b) We update the rod w
in the M step, keeping the projections onto the rod (black circles) fixed. (c) Another E step. The black circles
can ’slide’ along the rod, but the rod stays fixed. (d) Another M step. Adapted from Figure 12.12 of [Bis06].
Generated by pcaEmStepByStep.ipynb.


Now consider a different model with weights $\tilde{W} = W R$, where R is an arbitrary orthogonal rotation
matrix, satisfying $R R^T = I$. This has the same likelihood, since

$$\mathrm{Cov}[x] = \tilde{W} E[z z^T] \tilde{W}^T + E[\epsilon \epsilon^T] = W R R^T W^T + \Psi = W W^T + \Psi \tag{20.66}$$

Geometrically, multiplying W by an orthogonal matrix is like rotating z before generating x; but
since z is drawn from an isotropic Gaussian, this makes no difference to the likelihood. Consequently,
we cannot uniquely identify W, and therefore cannot uniquely identify the latent factors, either.

To break this symmetry, several solutions can be used, as we discuss below.

• Forcing W to have orthonormal columns. Perhaps the simplest solution to the identifiability
problem is to force W to have orthonormal columns. This is the approach adopted by PCA.
The resulting posterior estimate will then be unique, up to permutation of the latent dimensions.
(In PCA, this ordering ambiguity is resolved by sorting the dimensions in order of decreasing
eigenvalues of W.)

• Forcing W to be lower triangular. One way to resolve permutation unidentifiability, which
is popular in the Bayesian community (e.g., [LW04c]), is to ensure that the first visible feature is
only generated by the first latent factor, the second visible feature is only generated by the first
two latent factors, and so on. For example, if L = 3 and D = 4, the corresponding factor loading
matrix is given by

$$W = \begin{pmatrix} w_{11} & 0 & 0 \\ w_{21} & w_{22} & 0 \\ w_{31} & w_{32} & w_{33} \\ w_{41} & w_{42} & w_{43} \end{pmatrix} \tag{20.67}$$

We also require that $w_{kk} > 0$ for k = 1 : L. The total number of parameters in this constrained
matrix is D + DL − L(L − 1)/2, which is equal to the number of uniquely identifiable parameters in
FA.³ The disadvantage of this method is that the first L visible variables, known as the founder
variables, affect the interpretation of the latent factors, and so must be chosen carefully.

• Sparsity promoting priors on the weights. Instead of pre-specifying which entries in W are
zero, we can encourage the entries to be zero, using $\ell_1$ regularization [ZHT06], ARD [Bis99; AB08],
or spike-and-slab priors [Rat+09]. This is called sparse factor analysis. This does not necessarily
ensure a unique MAP estimate, but it does encourage interpretable solutions.

• Choosing an informative rotation matrix. There are a variety of heuristic methods that try
to find rotation matrices R which can be used to modify W (and hence the latent factors) so as
to try to increase the interpretability, typically by encouraging them to be (approximately) sparse.
One popular method is known as varimax [Kai58].

• Use of non-Gaussian priors for the latent factors. If we replace the prior on the latent
variables, p(z), with a non-Gaussian distribution, we can sometimes uniquely identify W, as well
as the latent factors. See e.g., [KIKKH20] for details.
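The rotational symmetry that all of these remedies address is easy to confirm numerically. The sketch below builds a random orthogonal matrix R (via a QR decomposition, an arbitrary choice made for this example) and checks that W and WR induce the same covariance; all parameter values are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
D, L = 4, 2
W = rng.normal(size=(D, L))
psi = rng.uniform(0.1, 1.0, size=D)             # diagonal of Psi

# Random orthogonal matrix from the QR decomposition of a Gaussian matrix.
R, _ = np.linalg.qr(rng.normal(size=(L, L)))
W_rot = W @ R

C = W @ W.T + np.diag(psi)
C_rot = W_rot @ W_rot.T + np.diag(psi)
```

Since the Gaussian likelihood depends on W only through the covariance, the two loading matrices cannot be told apart from data, even though they differ entry by entry.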


20.2.5 Nonlinear factor analysis 


The FA model assumes the observed data can be modeled as arising from a linear mapping from a
low-dimensional set of Gaussian factors. One way to relax this assumption is to let the mapping
from z to x be a nonlinear model, such as a neural network. That is, the model becomes

$$p(x) = \int \mathcal{N}(x | f(z; \theta), \Psi) \, \mathcal{N}(z | 0, I) \, dz \tag{20.68}$$

This is called nonlinear factor analysis. Unfortunately we can no longer compute the posterior or
the MLE exactly, so we need to use approximate methods. In Section 20.3.5, we discuss variational
autoencoders, which are the most common way to approximate a nonlinear FA model.

3. We get D parameters for $\Psi$ and DL for W, but we need to remove L(L − 1)/2 degrees of freedom coming from R,
since that is the dimensionality of the space of orthogonal matrices of size L × L. To see this, note that there are L − 1
free parameters in R in the first column (since the column vector must be normalized to unit length), there are L − 2
free parameters in the second column (which must be orthogonal to the first), and so on.

Figure 20.11: Mixture of factor analyzers as a PGM.


20.2.6 Mixtures of factor analysers 


The factor analysis model (Section 20.2) assumes the observed data can be modeled as arising from 
a linear mapping from a low-dimensional set of Gaussian factors. One way to relax this assumption 
is to assume the model is only locally linear, so the overall model becomes a (weighted) combination 
of FA models; this is called a mixture of factor analysers. The overall model for the data is a 
mixture of linear manifolds, which can be used to approximate an overall curved manifold. 

More precisely, let $m_n \in \{1, \ldots, K\}$ be a latent indicator specifying which subspace (cluster) we should
use to generate the data. If $m_n = k$, we sample $z_n$ from a Gaussian prior and pass it through the
$W_k$ matrix and add noise, where $W_k$ maps from the L-dimensional subspace to the D-dimensional
visible space.⁴ More precisely, the model is as follows:

$$p(x_n | z_n, m_n = k, \theta) = \mathcal{N}(x_n | \mu_k + W_k z_n, \Psi_k) \tag{20.69}$$
$$p(z_n | \theta) = \mathcal{N}(z_n | 0, I) \tag{20.70}$$
$$p(m_n | \theta) = \mathrm{Cat}(m_n | \pi) \tag{20.71}$$

This is called a mixture of factor analysers (MFA) [GH96]. The corresponding distribution in
the visible space is given by

$$p(x | \theta) = \sum_k p(m = k) \int p(z) \, p(x | z, m = k) \, dz = \sum_k \pi_k \int \mathcal{N}(z | 0, I) \, \mathcal{N}(x | \mu_k + W_k z, \Psi_k) \, dz \tag{20.72}$$

In the special case that $\Psi_k = \sigma^2 I$, we get a mixture of PPCA models (although it is difficult to
ensure orthogonality of the $W_k$ in this case). See Figure 20.12 for an example of the method applied
to some 2d data.

We can think of this as a low-rank version of a mixture of Gaussians. In particular, this model 
needs O(K LD) parameters instead of the O(K D?) parameters needed for a mixture of full covariance 
Gaussians. This can reduce overfitting. 
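To make the generative process concrete, here is a minimal sampler for a mixture of factor analysers. It assumes diagonal noise stored as one vector per component; the function name and all parameter values are invented for this example.

```python
import numpy as np

def sample_mfa(N, pis, mus, Ws, psis, rng):
    """Draw N samples from an MFA. Returns the data X and the cluster labels m."""
    K = len(pis)
    D, L = Ws[0].shape
    m = rng.choice(K, size=N, p=pis)               # cluster indicator per point
    z = rng.normal(size=(N, L))                    # latent factors, z ~ N(0, I)
    X = np.empty((N, D))
    for k in range(K):
        idx = np.flatnonzero(m == k)
        noise = rng.normal(size=(idx.size, D)) * np.sqrt(psis[k])
        X[idx] = mus[k] + z[idx] @ Ws[k].T + noise
    return X, m

rng = np.random.default_rng(0)
pis = np.array([0.3, 0.7])
mus = [np.array([-3.0, 0.0]), np.array([3.0, 1.0])]
Ws = [np.array([[1.0], [0.5]]), np.array([[0.2], [1.5]])]   # D = 2, L = 1
psis = [np.full(2, 0.05), np.full(2, 0.05)]
X, m = sample_mfa(20_000, pis, mus, Ws, psis, rng)
```

Each component contributes an elongated Gaussian with covariance $W_k W_k^T + \Psi_k$, using D(L + 1) covariance parameters rather than the D(D + 1)/2 of a full covariance matrix.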


4. If we allow $z_n$ to depend on $m_n$, we can let each subspace have a different dimensionality, as suggested in [KS15].

Figure 20.12: Mixture of PPCA models fit to a 2d dataset, using L = 1 latent dimensions. (a) K = 1 mixture
components. (b) K = 10 mixture components. Generated by mixPpcaDemo.ipynb.


20.2.7 Exponential family factor analysis 


So far we have assumed the observed data is real-valued, so $x_n \in \mathbb{R}^D$. If we want to model other
kinds of data (e.g., binary or categorical), we can simply replace the Gaussian output distribution
with a suitable member of the exponential family, where the natural parameters are given by a linear
function of $z_n$. That is, we use

$$p(x_n | z_n) = \exp(T(x)^T \theta + h(x) - g(\theta)) \tag{20.73}$$

where the N × D matrix of natural parameters is assumed to be given by the low rank decomposition
$\Theta = Z W$, where Z is N × L and W is L × D. The resulting model is called exponential family
factor analysis.

Unlike the linear-Gaussian FA, we cannot compute the exact posterior $p(z_n | x_n, W)$ due to the
lack of conjugacy between the expfam likelihood and the Gaussian prior. Furthermore, we cannot
compute the exact marginal likelihood either, which prevents us from finding the optimal MLE.

[CDS02] proposed a coordinate ascent method for a deterministic variant of this model, known as
exponential family PCA. This alternates between computing a point estimate of $z_n$ and W. This
can be regarded as a degenerate version of variational EM, where the E step uses a delta function
posterior for $z_n$. [GS08] present an improved algorithm that finds the global optimum, and [Ude+16]
presents an extension called generalized low rank models, that covers many different kinds of
loss function.

However, it is often preferable to use a probabilistic version of the model, rather than computing 
point estimates of the latent factors. In this case, we must represent the posterior using a non- 
degenerate distribution to avoid overfitting, since the number of latent variables is proportional to 
the number of datacases [WCS08]. Fortunately, we can use a non-degenerate posterior, such as a 
Gaussian, by optimizing the variational lower bound. We give some examples of this below. 


20.2.7.1 Example: binary PCA 
Consider a factored Bernoulli likelihood:

$$p(x | z) = \prod_d \mathrm{Ber}(x_d | \sigma(w_d^T z)) \tag{20.74}$$

Figure 20.13: (a) 150 synthetic 16 dimensional bit vectors. (b) The 2d embedding learned by binary PCA, fit
using variational EM. We have color coded points by the identity of the true “prototype” that generated them.
(c) Predicted probability of being on. (d) Thresholded predictions. Generated by binary_fa_demo.ipynb.

Suppose we observe N = 150 bit vectors of length D = 16. Each example is generated by choosing
one of three binary prototype vectors, and then by flipping bits at random. See Figure 20.13(a)
for the data. We can fit this using the variational EM algorithm (see [Tip98] for details). We
use L = 2 latent dimensions to allow us to visualize the latent space. In Figure 20.13(b), we plot
$E[z_n | x_n, \hat{W}]$. We see that the projected points group into three distinct clusters, as is to be expected.

In Figure 20.13(c), we plot the reconstructed version of the data, which is computed as follows:

$$p(\hat{x}_{nd} = 1 | x_n) = \int p(z_n | x_n) \, p(\hat{x}_{nd} = 1 | z_n) \, dz_n \tag{20.75}$$


If we threshold these probabilities at 0.5 (corresponding to a MAP estimate), we get the “denoised” 
version of the data in Figure 20.13(d). 
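The reconstruction integral has no closed form, but it is simple to approximate. The sketch below assumes the variational posterior for one data point is a Gaussian with mean m and covariance V (as produced by some fitting procedure not shown here) and estimates the integral by Monte Carlo; the function names and all numerical values are assumptions.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def reconstruct_bits(m, V, W, n_samples=5000, seed=0):
    """Monte Carlo estimate of p(x_d = 1 | x) under a Gaussian posterior N(z | m, V).

    W has shape (D, L), with row d holding the weight vector w_d.
    """
    rng = np.random.default_rng(seed)
    z = rng.multivariate_normal(m, V, size=n_samples)   # (n_samples, L)
    return sigmoid(z @ W.T).mean(axis=0)                 # average over posterior samples

rng = np.random.default_rng(1)
W = rng.normal(size=(16, 2))        # made-up weights for D = 16, L = 2
m = np.array([0.5, -1.0])           # made-up posterior mean
V = 0.1 * np.eye(2)                 # made-up posterior covariance
probs = reconstruct_bits(m, V, W)
denoised = (probs > 0.5).astype(int)   # thresholded predictions
```

As the posterior covariance shrinks to zero, the estimate reduces to the plug-in prediction obtained by passing the posterior mean through the sigmoid.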
Figure 20.14: Gaussian latent factor models for paired data. (a) Supervised PCA. (b) Partial least squares.


20.2.7.2 Example: categorical PCA 


We can generalize the model in Section 20.2.7.1 to handle categorical data by using the following 
likelihood: 


$$p(x | z) = \prod_d \mathrm{Cat}(x_d | \mathrm{softmax}(W_d z)) \tag{20.76}$$
d 


We call this categorical PCA (CatPCA). A variational EM algorithm for fitting this is described 
in [Kha+10]. 


20.2.8 Factor analysis models for paired data 


In this section, we discuss linear-Gaussian factor analysis models when we have two kinds of observed
variables, $x \in \mathbb{R}^{D_x}$ and $y \in \mathbb{R}^{D_y}$, which are paired. These often correspond to different sensors or
modalities (e.g., images and sound). We follow the presentation of [Vir10].


20.2.8.1 Supervised PCA 


In supervised PCA [Yu+06], we model the joint p(x, y) using a shared low-dimensional
representation using the following linear Gaussian model:

$$p(z_n) = \mathcal{N}(z_n | 0, I_L) \tag{20.77}$$
$$p(x_n | z_n, \theta) = \mathcal{N}(x_n | W_x z_n, \sigma_x^2 I_{D_x}) \tag{20.78}$$
$$p(y_n | z_n, \theta) = \mathcal{N}(y_n | W_y z_n, \sigma_y^2 I_{D_y}) \tag{20.79}$$

This is illustrated as a graphical model in Figure 20.14a. The intuition is that $z_n$ is a shared latent
subspace, that captures features that $x_n$ and $y_n$ have in common. The variance terms $\sigma_x$ and $\sigma_y$
control how much emphasis the model puts on the two different signals. If we put a prior on the
parameters $\theta = (W_x, W_y, \sigma_x, \sigma_y)$, we recover the Bayesian factor regression model of [Wes03].

We can marginalize out $z_n$ to get $p(y_n | x_n)$. If $y_n$ is a scalar, this becomes

$$p(y_n | x_n, \theta) = \mathcal{N}(y_n | x_n^T v, \; w_y^T C w_y + \sigma_y^2) \tag{20.80}$$
$$C = (I + \sigma_x^{-2} W_x^T W_x)^{-1} \tag{20.81}$$
$$v = \sigma_x^{-2} W_x C w_y \tag{20.82}$$

To apply this to the classification setting, we can use supervised ePCA [Guo09], in which we 
replace the Gaussian p(y|z) with a logistic regression model. 

This model is completely symmetric in x and y. If our goal is to predict y from x via the latent
bottleneck z, then we might want to upweight the likelihood term for y, as proposed in [Ris+08].
This gives

$$p(X, Y, Z | \theta) = p(Y | Z, W_y) \, p(X | Z, W_x)^{\alpha} \, p(Z) \tag{20.83}$$

where $\alpha \leq 1$ controls the relative importance of modeling the two sources. The value of $\alpha$ can be
chosen by cross-validation.


20.2.8.2 Partial least squares 


Another way to improve the predictive performance in supervised tasks is to allow the inputs x
to have their own “private” noise source that is independent of the target variable, since not all
variation in x is relevant for predictive purposes. We can do this by introducing an extra latent
variable $z_n^x$ just for the inputs, that is different from $z_n^s$, which is the shared bottleneck between $x_n$
and $y_n$. In the Gaussian case, the overall model has the form

$$p(z_n) = \mathcal{N}(z_n^s | 0, I) \, \mathcal{N}(z_n^x | 0, I) \tag{20.84}$$
$$p(x_n | z_n, \theta) = \mathcal{N}(x_n | W_x z_n^s + B_x z_n^x, \sigma_x^2 I) \tag{20.85}$$
$$p(y_n | z_n, \theta) = \mathcal{N}(y_n | W_y z_n^s, \sigma_y^2 I) \tag{20.86}$$

See Figure 20.14b. MLE for $\theta$ in this model is equivalent to the technique of partial least squares
(PLS) [Gus01; Nou+02; Sun+09].


20.2.8.3 Canonical correlation analysis 


In some cases, we want to use a fully symmetric model, so we can capture the dependence between x
and y, while allowing for domain-specific or “private” noise sources. We can do this by introducing a
latent variable $z_n^x$ just for $x_n$, a latent variable $z_n^y$ just for $y_n$, and a shared latent variable $z_n^s$. In
the Gaussian case, the overall model has the form

$$p(z_n) = \mathcal{N}(z_n^s | 0, I) \, \mathcal{N}(z_n^x | 0, I) \, \mathcal{N}(z_n^y | 0, I) \tag{20.87}$$
$$p(x_n | z_n, \theta) = \mathcal{N}(x_n | W_x z_n^s + B_x z_n^x, \sigma_x^2 I) \tag{20.88}$$
$$p(y_n | z_n, \theta) = \mathcal{N}(y_n | W_y z_n^s + B_y z_n^y, \sigma_y^2 I) \tag{20.89}$$

where $W_x$ is $D_x \times L^s$ dimensional, $W_y$ is $D_y \times L^s$ dimensional, $B_x$ is $D_x \times L^x$ dimensional, and $B_y$ is
$D_y \times L^y$ dimensional. See Figure 20.15 for the PGM.


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 




Figure 20.15: Canonical correlation analysis as a PGM. 


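If we marginalize out all the latent variables, we get the following distribution on the visibles
(where we assume $\sigma_x = \sigma_y = \sigma$):

$$p(x_n, y_n) = \int p(z_n) \, p(x_n, y_n | z_n) \, dz_n = \mathcal{N}(x_n, y_n | \mu, W W^T + \sigma^2 I) \tag{20.91}$$

where $\mu = (\mu_x; \mu_y)$, and $W = [W_x; W_y]$. Thus the induced covariance is the following low rank
matrix:

$$W W^T = \begin{pmatrix} W_x W_x^T & W_x W_y^T \\ W_y W_x^T & W_y W_y^T \end{pmatrix} \tag{20.92}$$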

[BJ05] showed that MLE for this model is equivalent to a classical statistical method known as 
canonical correlation analysis or CCA [Hot36]. However, the PGM perspective allows us to 
easily generalize to multiple kinds of observations (this is known as generalized CCA [Hor61]) or 
to nonlinear models (this is known as deep CCA [WLL16; SNM16]), or exponential family CCA 
[KVK10]. See [Uur+17] for further discussion of CCA and its extensions. 


20.3 Autoencoders 


We can think of PCA (Section 20.1) and factor analysis (Section 20.2) as learning a (linear) mapping
from x → z, called the encoder, $f_e$, and learning another (linear) mapping z → x, called the
decoder, $f_d$. The overall reconstruction function has the form $r(x) = f_d(f_e(x))$. The model is
trained to minimize $\mathcal{L}(\theta) = \|r(x) - x\|_2^2$. More generally, we can use $\mathcal{L}(\theta) = -\log p(x | r(x))$.

In this section, we consider the case where the encoder and decoder are nonlinear mappings
implemented by neural networks. This is called an autoencoder. If we use an MLP with one hidden
layer, we get the model shown in Figure 20.16. We can think of the hidden units in the middle as a
low-dimensional bottleneck between the input and its reconstruction.

Of course, if the hidden layer is wide enough, there is nothing to stop this model from learning the
identity function. To prevent this degenerate solution, we have to restrict the model in some way. The
simplest approach is to use a narrow bottleneck layer, with L ≪ D; this is called an undercomplete
representation. The other approach is to use L ≫ D, known as an overcomplete representation,
but to impose some other kind of regularization, such as adding noise to the inputs, forcing the
activations of the hidden units to be sparse, or imposing a penalty on the derivatives of the hidden
units. We discuss these options in more detail below.

Figure 20.16: An autoencoder with one hidden layer.


20.3.1 Bottleneck autoencoders 


We start by considering the special case of a linear autoencoder, in which there is one hidden layer,
the hidden units are computed using $z = W_1 x$, and the output is reconstructed using $\hat{x} = W_2 z$,
where $W_1$ is an L × D matrix, $W_2$ is a D × L matrix, and L < D. Hence $\hat{x} = W_2 W_1 x = W x$
is the output of the model. If we train this model to minimize the squared reconstruction error,
$\mathcal{L}(W) = \sum_{n=1}^{N} \|x_n - W x_n\|_2^2$, one can show [BH89; KJ95] that W is an orthogonal projection onto
the first L eigenvectors of the empirical covariance matrix of the data. This is therefore equivalent to
PCA.
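The optimum described above can be checked without training anything. The sketch below forms the rank-L projection onto the top eigenvectors of the empirical covariance and compares its reconstruction error with that of a random rank-L orthogonal projection. It assumes the data has been centered (the linear model above has no bias term), and the synthetic data is made up.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, L = 500, 6, 2
X = rng.normal(size=(N, D)) @ rng.normal(size=(D, D))   # synthetic correlated data
Xc = X - X.mean(axis=0)                                  # centering is an assumption

S = Xc.T @ Xc / N
evals, evecs = np.linalg.eigh(S)          # eigenvalues in ascending order
U = evecs[:, -L:]                         # top-L eigenvectors
W_pca = U @ U.T                           # the optimal W = W2 W1 (a projection)
err_pca = np.sum((Xc - Xc @ W_pca) ** 2)

# A random rank-L orthogonal projection, for comparison.
Q, _ = np.linalg.qr(rng.normal(size=(D, L)))
err_rand = np.sum((Xc - Xc @ (Q @ Q.T)) ** 2)
```

The PCA reconstruction error equals N times the sum of the discarded eigenvalues, which gives a closed-form target that a trained linear autoencoder should approach.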

If we introduce nonlinearities into the autoencoder, we get a model that is strictly more powerful 
than PCA, as proved in [JHG00]. Such methods can learn very useful low dimensional representations 
of data. 

Consider fitting an autoencoder to the Fashion MNIST dataset. We consider both an MLP 
architecture (with 2 layers and a bottleneck of size 30), and a CNN based architecture (with 3 layers 
and a 3d bottleneck with 64 channels). We use a Bernoulli likelihood model and binary cross entropy 
as the loss. Figure 20.17 shows some test images and their reconstructions. We see that the CNN 
model reconstructs the images more accurately than the MLP model. However, both models are 
small, and were only trained for 5 epochs; results can be improved by using larger models, and 
training for longer. 

Figure 20.18 visualizes the first 2 (of 30) latent dimensions produced by the MLP-AE. More
precisely, we plot the tSNE embeddings (see Section 20.4.10), color coded by class label. We also
show some corresponding images from the dataset, from which the embeddings were derived. We see
that the method has done a good job of separating the classes in a fully unsupervised way. We also
see that the latent space of the MLP and CNN models is very similar (at least when viewed through
this 2d projection).

Figure 20.17: Results of applying an autoencoder to the Fashion MNIST data. Top row are first 5 images
from validation set. Bottom row are reconstructions. (a) MLP model (trained for 20 epochs). The encoder is
an MLP with architecture 784-100-30. The decoder is the mirror image of this. (b) CNN model (trained for 5
epochs). The encoder is a CNN model with architecture Conv2D(16, 3 × 3, same, selu), MaxPool2D(2 × 2),
Conv2D(32, 3 × 3, same, selu), MaxPool2D(2 × 2), Conv2D(64, 3 × 3, same, selu), MaxPool2D(2 × 2). The
decoder is the mirror image of this, using transposed convolution and without the max pooling layers. Adapted
from Figure 17.4 of [Gér19]. Generated by ae_mnist_tf.ipynb.


(a) (b) 


Figure 20.18: tSNE plot of the first 2 latent dimensions of the Fashion MNIST validation set using an 
autoencoder. (a) MLP. (b) CNN. Adapted from Figure 17.5 of [Gér19]. Generated by ae_mnist_tf.ipynb. 


20.3.2 Denoising autoencoders 


One useful way to control the capacity of an autoencoder is to add noise to its input, and then 
train the model to reconstruct a clean (uncorrupted) version of the original input. This is called a 
denoising autoencoder [Vin+10a].

We can implement this by adding Gaussian noise, or using Bernoulli dropout. Figure 20.19 shows 
some reconstructions of corrupted images computed using a DAE. We see that the model is able to 
“hallucinate” details that are missing in the input, since it has seen similar images before, and can 
store this information in the parameters of the model. 
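The two corruption processes mentioned above can be sketched as follows. This is only an illustration of the noise models: the function names are assumptions, and the dropout variant here simply zeros entries without rescaling the survivors, which may differ from what a particular deep learning library does.

```python
import numpy as np

def corrupt_gaussian(X, sigma, rng):
    """Additive isotropic Gaussian noise with standard deviation sigma."""
    return X + sigma * rng.normal(size=X.shape)

def corrupt_dropout(X, rate, rng):
    """Bernoulli dropout: each entry is set to zero independently with probability rate."""
    keep = rng.random(X.shape) >= rate
    return X * keep

rng = np.random.default_rng(0)
X = np.ones((1000, 50))                   # stand-in for a batch of flattened images
X_gauss = corrupt_gaussian(X, 0.5, rng)
X_drop = corrupt_dropout(X, 0.3, rng)
```

During training, the corrupted batch is fed to the encoder while the loss compares the decoder output with the clean batch X.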

Suppose we train a DAE using Gaussian corruption and squared error reconstruction, i.e., we use 
pe(ž|x£) = N(#|x, 071) and (x, r(%)) = ||e||2, where e(x) = r(#)—z= is the residual error for example 
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Figure 20.19: Denoising autoencoder (MLP architecture) applied to some noisy Fashion MNIST images from 
the validation set. (a) Gaussian noise. (b) Bernoulli dropout noise. Top row: input. Bottom row: output. 
Adapted from Figure 17.9 of [Gér19]. Generated by ae_ mmnist_ tf.ipynb. 
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Figure 20.20: The residual error from a DAE, e(x) = r(&) — x, can learn a vector field corresponding to 
the score function. Arrows point towards higher probability regions. The length of the arrow is proportional 
to \|e(a)||, so points near the 1d data manifold (represented by the curved line) have smaller arrows. From 
Figure 5 of [AB14]. Used with kind permission of Guillaume Alain. 


$\mathbf{x}$. Then one can show [AB14] the remarkable result that, as $\sigma \to 0$ (and with a sufficiently powerful
model and enough data), the residuals approximate the score function, which is the gradient of the log
probability density of the data, i.e., $\mathbf{e}(\mathbf{x}) \approx \nabla_{\mathbf{x}} \log p(\mathbf{x})$. That is, the DAE learns a vector field, corresponding to the
gradient of the log data density. Thus points that are close to the data manifold will be projected
onto it via the sampling process. See Figure 20.20 for an illustration.


20.3.3 Contractive autoencoders 


A different way to regularize autoencoders is by adding the penalty term 


$$\Omega(\mathbf{z}, \mathbf{x}) = \lambda \left\| \frac{\partial f_e(\mathbf{x})}{\partial \mathbf{x}} \right\|_F^2 = \lambda \sum_k \|\nabla_{\mathbf{x}} h_k(\mathbf{x})\|_2^2 \tag{20.93}$$


to the reconstruction loss, where $h_k$ is the value of the k'th hidden embedding unit. That is, we
penalize the squared Frobenius norm of the encoder's Jacobian. This is called a contractive autoencoder
[Rif+11]. (A linear operator with Jacobian $\mathbf{J}$ is called a contraction if $\|\mathbf{J}\mathbf{x}\| < 1$ for all unit-norm
inputs $\mathbf{x}$.)
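
To make the penalty concrete, here is a minimal sketch for the special case of a single-layer sigmoid encoder, h = sigmoid(Wx + b). This architecture is an assumption made for illustration, not the model used in the text. In this case the Jacobian has entries dh_k/dx_j = h_k (1 - h_k) W_kj, so its squared Frobenius norm has a closed form, which we check against finite differences.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def encode(W, b, x):
    """Single-layer encoder: h_k = sigmoid(sum_j W[k][j] * x[j] + b[k])."""
    return [sigmoid(sum(w * xj for w, xj in zip(Wk, x)) + bk)
            for Wk, bk in zip(W, b)]

def contractive_penalty(W, b, x, lam):
    """lam * ||dh/dx||_F^2, using dh_k/dx_j = h_k (1 - h_k) W[k][j]."""
    h = encode(W, b, x)
    return lam * sum((hk * (1.0 - hk)) ** 2 * sum(w * w for w in Wk)
                     for hk, Wk in zip(h, W))

def contractive_penalty_fd(W, b, x, lam, eps=1e-5):
    """The same quantity via central finite differences (a sanity check)."""
    total = 0.0
    for j in range(len(x)):
        x_plus = list(x)
        x_plus[j] += eps
        x_minus = list(x)
        x_minus[j] -= eps
        h_plus, h_minus = encode(W, b, x_plus), encode(W, b, x_minus)
        total += sum(((p - m) / (2 * eps)) ** 2 for p, m in zip(h_plus, h_minus))
    return lam * total

# Hypothetical toy weights: 2 hidden units, 3 inputs.
W = [[0.5, -1.0, 0.3], [1.2, 0.1, -0.7]]
b = [0.0, 0.1]
x = [0.4, -0.2, 0.9]
exact = contractive_penalty(W, b, x, lam=0.1)
approx = contractive_penalty_fd(W, b, x, lam=0.1)
```

For deeper encoders no such closed form is available, and the Jacobian must be computed by automatic differentiation, which is the source of the training cost mentioned later in this section.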

To understand why this is useful, consider Figure 20.20. We can approximate the curved low- 
dimensional manifold by a series of locally linear manifolds. These linear approximations can be 
computed using the Jacobian of the encoder at each point. By encouraging these to be contractive, 
we ensure the model “pushes” inputs that are off the manifold to move back towards it. 

Another way to think about CAEs is as follows. To minimize the penalty term, the model would 
like to ensure the encoder is a constant function. However, if it was completely constant, it would 
ignore its input, and hence incur high reconstruction cost. Thus the two terms together encourage 
the model to learn a representation where only a few units change in response to the most significant 
variations in the input. 

One possible degenerate solution is that the encoder simply learns to multiply the input by a
small constant $\epsilon$ (which scales down the Jacobian), followed by a decoder that divides by $\epsilon$ (which
reconstructs perfectly). To avoid this, we can tie the weights of the encoder and decoder, by setting
the weight matrix for layer $\ell$ of $f_d$ to be the transpose of the weight matrix for layer $\ell$ of $f_e$, but
using untied bias terms. Unfortunately CAEs are slow to train, because of the expense of computing
the Jacobian.


20.3.4 Sparse autoencoders 


Yet another way to regularize autoencoders is to add a sparsity penalty to the latent activations of
the form $\Omega(\mathbf{z}) = \lambda \|\mathbf{z}\|_1$. (This is called activity regularization.)

An alternative way to implement sparsity, that often gives better results, is to use logistic units, 
and then to compute the expected fraction of time each unit k is on within a minibatch (call this
$q_k$), and ensure that this is close to a desired target value p, as proposed in [GBB11]. In particular,
we use the regularizer $\Omega(\mathbf{z}_{1:L,1:N}) = \lambda \sum_k D_{\mathrm{KL}}(\mathbf{p} \,\|\, \mathbf{q}_k)$ for latent dimensions 1 : L and examples
1 : N, where $\mathbf{p} = (p, 1-p)$ is the desired target distribution, and $\mathbf{q}_k = (q_k, 1 - q_k)$ is the empirical
distribution for unit k, computed using $q_k = \frac{1}{N} \sum_{n=1}^N \mathbb{I}(z_{n,k} = 1)$.
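
The KL sparsity regularizer can be sketched as follows. This is an illustrative implementation under one stated assumption: since logistic units output values in [0, 1] rather than exact 0/1 indicators, we estimate q_k by the mean activation of unit k over the minibatch. The function names and the clipping constant are hypothetical.

```python
import math

def bernoulli_kl(p, q):
    """KL( Ber(p) || Ber(q) ) for 0 < p, q < 1."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))

def kl_sparsity_penalty(activations, target, lam, eps=1e-7):
    """lam * sum_k KL(p || q_k), where activations[n][k] lies in [0, 1].

    q_k is estimated as the mean activation of unit k over the minibatch
    (an assumption of this sketch).
    """
    num_examples = len(activations)
    num_units = len(activations[0])
    penalty = 0.0
    for k in range(num_units):
        q_k = sum(row[k] for row in activations) / num_examples
        q_k = min(max(q_k, eps), 1.0 - eps)   # clip to keep the logs finite
        penalty += bernoulli_kl(target, q_k)
    return lam * penalty

# Hypothetical minibatch: 2 examples, 2 hidden units.
too_active = [[0.9, 0.8], [0.7, 0.9]]
penalty = kl_sparsity_penalty(too_active, target=0.1, lam=1.0)
```

The penalty is zero only when every unit's average activity equals the target p, so units are discouraged both from firing too often and from switching off permanently.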

Figure 20.21 shows the results when fitting an AE-MLP (with 300 hidden units) to Fashion MNIST. 
If we set $\lambda = 0$ (i.e., if we don't impose a sparsity penalty), we see that the average activation value
is about 0.4, with most neurons being partially activated most of the time. With the $\ell_1$ penalty, we
see that most units are off all the time, which means they are not being used at all. With the KL
penalty, we see that about 70% of neurons are off on average, but unlike the $\ell_1$ case, we don't see
units being permanently turned off (the average activation level is 0.1). This latter kind of sparse
firing pattern is similar to that observed in biological brains (see e.g., [Bey+19]).


20.3.5 Variational autoencoders 


In this section, we discuss the variational autoencoder or VAE [KW14; RMW14; KW19a], which
can be thought of as a probabilistic version of a deterministic autoencoder (Section 20.3). The principal
advantage is that a VAE is a generative model that can create new samples, whereas an autoencoder 
just computes embeddings of input vectors. 

We discuss VAEs in detail in the sequel to this book, [Mur23]. However, in brief, the VAE combines 
two key ideas. First we create a non-linear extension of the factor analysis generative model, i.e., we 

Figure 20.21: Neuron activity (in the bottleneck layer) for an autoencoder applied to Fashion MNIST. We
show results for three models, with different kinds of sparsity penalty: no penalty (left column), $\ell_1$ penalty
(middle column), KL penalty (right column). Top row: Heatmap of 300 neuron activations (columns) across
100 examples (rows). Middle row: Histogram of activation levels derived from this heatmap. Bottom row:
Histogram of the mean activation per neuron, averaged over all examples in the validation set. Adapted from
Figure 17.11 of [Gér19]. Generated by ae_mnist_tf.ipynb.


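replace $p(\mathbf{x}|\mathbf{z}) = \mathcal{N}(\mathbf{x}|\mathbf{W}\mathbf{z}, \sigma^2 \mathbf{I})$ with


$$p_{\boldsymbol{\theta}}(\mathbf{x}|\mathbf{z}) = \mathcal{N}(\mathbf{x} \mid f_d(\mathbf{z}; \boldsymbol{\theta}), \sigma^2 \mathbf{I}) \tag{20.94}$$

where $f_d$ is the decoder. For binary observations we should use a Bernoulli likelihood:

$$p(\mathbf{x}|\mathbf{z}, \boldsymbol{\theta}) = \prod_{i=1}^{D} \mathrm{Ber}(x_i \mid f_d(\mathbf{z}; \boldsymbol{\theta})_i) \tag{20.95}$$

where the i'th output of the decoder is interpreted as the probability that $x_i = 1$.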


Second, we create another model, q(z|x), called the recognition network or inference network, 
that is trained simultaneously with the generative model to do approximate posterior inference. If 
we assume the posterior is Gaussian, with diagonal covariance, we get 


$$q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\mathbf{z} \mid f_{e,\mu}(\mathbf{x}; \boldsymbol{\phi}), \mathrm{diag}(f_{e,\sigma}(\mathbf{x}; \boldsymbol{\phi}))) \tag{20.96}$$

Figure 20.22: Schematic illustration of a VAE. From a figure from http://krasserm.github.io/2018/07/27/dfc-vae/.
Used with kind permission of Martin Krasser.


where $f_e$ is the encoder. See Figure 20.22 for a sketch.

The idea of training an inference network to “invert” a generative network, rather than running 
an optimization algorithm to infer the latent code, is called amortized inference. This idea was 
first proposed in the Helmholtz machine [Day+95]. However, that paper did not present a single
unified objective function for inference and generation, but instead used the wake-sleep method for
training, which alternates between optimizing the generative model and inference model. By contrast, 
the VAE optimizes a variational lower bound on the log-likelihood, which is more principled, since it 
is a single unified objective. 


20.3.5.1 Training VAEs 


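We cannot compute the exact marginal likelihood $p(\mathbf{x}|\boldsymbol{\theta})$ needed for MLE training, because posterior
inference in a nonlinear FA model is intractable. However, we can use the inference network to
compute an approximate posterior, $q(\mathbf{z}|\mathbf{x})$. We can then use this to compute the evidence lower
bound or ELBO. For a single example $\mathbf{x}$, this is given by


\begin{align}
\mathcal{L}(\boldsymbol{\theta}, \boldsymbol{\phi}|\mathbf{x}) &= \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x})}\left[\log p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z}) - \log q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x})\right] \tag{20.97} \\
&= \mathbb{E}_{q(\mathbf{z}|\mathbf{x}, \boldsymbol{\phi})}\left[\log p(\mathbf{x}|\mathbf{z}, \boldsymbol{\theta})\right] - D_{\mathrm{KL}}(q(\mathbf{z}|\mathbf{x}, \boldsymbol{\phi}) \,\|\, p(\mathbf{z})) \tag{20.98}
\end{align}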


This can be interpreted as the expected log likelihood, plus a regularizer, that penalizes the posterior 
from deviating too much from the prior. (This is different than the approach in Section 20.3.4, where 
we applied the KL penalty to the aggregate posterior in each minibatch.) 

The ELBO is a lower bound of the log marginal likelihood (aka evidence), as can be seen from 
Jensen’s inequality: 


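\begin{align}
\mathcal{L}(\boldsymbol{\theta}, \boldsymbol{\phi}|\mathbf{x}) &= \int q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x}) \log \frac{p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})}{q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x})} d\mathbf{z} \tag{20.99} \\
&\leq \log \int q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x}) \frac{p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})}{q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x})} d\mathbf{z} = \log p_{\boldsymbol{\theta}}(\mathbf{x}) \tag{20.100}
\end{align}

Figure 20.23: Computation graph for VAEs, where $p(\mathbf{z}) = \mathcal{N}(\mathbf{z}|\mathbf{0}, \mathbf{I})$, $p(\mathbf{x}|\mathbf{z}, \boldsymbol{\theta}) = \mathcal{N}(\mathbf{x}|f(\mathbf{z}), \sigma^2 \mathbf{I})$, and
$q(\mathbf{z}|\mathbf{x}, \boldsymbol{\phi}) = \mathcal{N}(\mathbf{z}|\boldsymbol{\mu}(\mathbf{x}), \boldsymbol{\Sigma}(\mathbf{x}))$. Red boxes show sampling operations which are not differentiable. Blue boxes
show loss layers (we assume Gaussian likelihoods and priors). (Left) Without the reparameterization trick.
(Right) With the reparameterization trick. Gradients can flow from the output loss, back through the decoder
and into the encoder. From Figure 4 of [Doe16]. Used with kind permission of Carl Doersch.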


Thus for fixed inference network parameters $\boldsymbol{\phi}$, increasing the ELBO should increase the log likelihood
of the data, similar to EM (Section 8.7.2).


20.3.5.2 The reparameterization trick


In this section, we discuss how to compute the ELBO and its gradient. For simplicity, let us suppose 
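that the inference network estimates the parameters of a Gaussian posterior. Since $q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x})$ is
Gaussian, we can write


$$\mathbf{z} = f_{e,\mu}(\mathbf{x}; \boldsymbol{\phi}) + f_{e,\sigma}(\mathbf{x}; \boldsymbol{\phi}) \odot \boldsymbol{\epsilon} \tag{20.101}$$


where $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$. Hence


$$\mathcal{L}(\boldsymbol{\theta}, \boldsymbol{\phi}|\mathbf{x}) = \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})}\left[\log p_{\boldsymbol{\theta}}(\mathbf{x} \mid \mathbf{z} = \boldsymbol{\mu}_{\boldsymbol{\phi}}(\mathbf{x}) + \boldsymbol{\sigma}_{\boldsymbol{\phi}}(\mathbf{x}) \odot \boldsymbol{\epsilon})\right] - D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})) \tag{20.102}$$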


Now the expectation is independent of the parameters of the model, so we can safely push gradients
inside and use backpropagation for training in the usual way, by minimizing $-\mathbb{E}_{\mathbf{x} \sim \mathcal{D}}\left[\mathcal{L}(\boldsymbol{\theta}, \boldsymbol{\phi}|\mathbf{x})\right]$ wrt
$\boldsymbol{\theta}$ and $\boldsymbol{\phi}$. This is known as the reparameterization trick. See Figure 20.23 for an illustration.
The first term in the ELBO can be approximated by sampling $\boldsymbol{\epsilon}$, scaling it by the output of the
inference network to get $\mathbf{z}$, and then evaluating $\log p(\mathbf{x}|\mathbf{z})$ using the decoder network.
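
The following sketch illustrates why the reparameterization makes gradients easy to estimate. As a stand-in for the decoder term log p(x|z), we use the toy function f(z) = z^2 in one dimension (an assumption made purely so that the exact answer is known): since E[z^2] = mu^2 + sigma^2, the true gradients are 2 mu and 2 sigma.

```python
import random

def reparameterize(mu, sigma, rng):
    """Return z = mu + sigma * eps together with the noise eps ~ N(0, I)."""
    eps = [rng.gauss(0.0, 1.0) for _ in mu]
    z = [m + s * e for m, s, e in zip(mu, sigma, eps)]
    return z, eps

def pathwise_gradients(mu, sigma, num_samples, rng):
    """Monte Carlo gradients of E[z^2] with respect to mu and sigma (1d case).

    Because z = mu + sigma * eps, we have dz/dmu = 1 and dz/dsigma = eps, so
    the chain rule gives d(z^2)/dmu = 2 z and d(z^2)/dsigma = 2 z eps.
    """
    grad_mu = 0.0
    grad_sigma = 0.0
    for _ in range(num_samples):
        (z,), (eps,) = reparameterize([mu], [sigma], rng)
        grad_mu += 2.0 * z
        grad_sigma += 2.0 * z * eps
    return grad_mu / num_samples, grad_sigma / num_samples

rng = random.Random(0)
g_mu, g_sigma = pathwise_gradients(mu=1.5, sigma=0.5, num_samples=50_000, rng=rng)
# Exact values for comparison: 2 * mu = 3.0 and 2 * sigma = 1.0.
```

The randomness enters only through eps, which does not depend on mu or sigma, so each sample is a deterministic, differentiable function of the parameters. In a real VAE the derivative of f would be supplied by backpropagation through the decoder.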
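The second term in the ELBO is the KL of two Gaussians, which has a closed form solution. In
particular, inserting $p(\mathbf{z}) = \mathcal{N}(\mathbf{z}|\mathbf{0}, \mathbf{I})$ and $q(\mathbf{z}) = \mathcal{N}(\mathbf{z}|\boldsymbol{\mu}, \mathrm{diag}(\boldsymbol{\sigma}^2))$ into Equation (6.33), we get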


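\begin{align}
D_{\mathrm{KL}}(q \,\|\, p) &= \sum_{k=1}^{K} \left[ \log\left(\frac{1}{\sigma_k}\right) + \frac{\sigma_k^2 + (\mu_k - 0)^2}{2 \cdot 1} - \frac{1}{2} \right] \\
&= -\frac{1}{2} \sum_{k=1}^{K} \left[ \log \sigma_k^2 - \sigma_k^2 - \mu_k^2 + 1 \right] \tag{20.103}
\end{align}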
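
As a quick numerical check of this closed form, the sketch below evaluates the KL divergence between a diagonal Gaussian with means mu_k and standard deviations sigma_k and a standard normal prior, in both the compact form and the expanded per-dimension form. The function names are illustrative.

```python
import math

def kl_to_standard_normal(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), compact form."""
    return -0.5 * sum(math.log(s * s) - s * s - m * m + 1.0
                      for m, s in zip(mu, sigma))

def kl_to_standard_normal_expanded(mu, sigma):
    """The same quantity, written as a sum of 1d Gaussian KL terms."""
    return sum(math.log(1.0 / s) + (s * s + (m - 0.0) ** 2) / 2.0 - 0.5
               for m, s in zip(mu, sigma))

mu = [0.3, -1.2]
sigma = [0.5, 2.0]
kl_compact = kl_to_standard_normal(mu, sigma)
kl_expanded = kl_to_standard_normal_expanded(mu, sigma)
```

The divergence vanishes when mu = 0 and sigma = 1, i.e., when the approximate posterior equals the prior, and grows as the posterior moves away from the origin or changes its scale.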


Figure 20.24: Reconstructing MNIST digits using a 20 dimensional latent space. Top row: input images. 
Bottom row: reconstructions. (a) VAE. Generated by vae_mnist_conv_lightning.ipynb. (b) Deterministic 
AE. Generated by ae_mnist_conv.ipynb.




Figure 20.25: Sampling MNIST digits using a 20 dimensional latent space. (a) VAE. Generated by 
vae_mnist_conv_lightning.ipynb. (b) Deterministic AE. Generated by ae_mnist_conv.ipynb.


20.3.5.3 Comparison of VAEs and autoencoders 


VAEs are very similar to autoencoders. In particular, the generative model, $p_{\boldsymbol{\theta}}(\mathbf{x}|\mathbf{z})$, acts like the
decoder, and the inference network, $q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x})$, acts like the encoder. The reconstruction abilities of
both models are similar, as can be seen by comparing Figure 20.24a with Figure 20.24b.

The primary advantage of the VAE is that it can be used to generate new data from random
noise. In particular, we sample $\mathbf{z}$ from the Gaussian prior $\mathcal{N}(\mathbf{z}|\mathbf{0}, \mathbf{I})$, and then pass this through the
decoder to get $\mathbb{E}[\mathbf{x}|\mathbf{z}] = f_d(\mathbf{z}; \boldsymbol{\theta})$. The VAE's decoder is trained to convert random points in the
embedding space (generated by perturbing the input encodings) to sensible outputs. By contrast, 
the decoder for the deterministic autoencoder only ever gets as inputs the exact encodings of the 
training set, so it does not know what to do with random inputs that are outside what it was trained 
on. So a standard autoencoder cannot create new samples. This difference can be seen by comparing 
Figure 20.25a with Figure 20.25b. 

The reason the VAE is better at sampling is that it embeds images into Gaussians in latent space,
whereas the AE embeds images into points, which are like delta functions. The advantage of using a 

Figure 20.26: tSNE projection of a 20 dimensional latent space. (a) VAE. Generated by
vae_mnist_conv_lightning.ipynb. (b) Deterministic AE. Generated by ae_mnist_conv.ipynb.

Figure 20.27: Linear interpolation between the left and right images in a 20 dimensional latent space. (a)
VAE. (b) Deterministic AE. Generated by vae_mnist_conv_lightning.ipynb.


latent distribution is that it encourages local smoothness, since a given image may map to multiple 
nearby places, depending on the stochastic sampling. By contrast, in an AE, the latent space is 
typically not smooth, so images from different classes often end up next to each other. This difference 
can be seen by comparing Figure 20.26a with Figure 20.26b. 

We can leverage the smoothness of the latent space to perform image interpolation. Rather than 
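working in pixel space, we can work in the latent space of the model. Specifically, let $\mathbf{x}_1$ and $\mathbf{x}_2$ be
two images, and let $\mathbf{z}_1 = \mathbb{E}_{q(\mathbf{z}|\mathbf{x}_1)}[\mathbf{z}]$ and $\mathbf{z}_2 = \mathbb{E}_{q(\mathbf{z}|\mathbf{x}_2)}[\mathbf{z}]$ be their encodings. We can now generate
new images that interpolate between these two anchors by computing $\mathbf{z} = \lambda \mathbf{z}_1 + (1 - \lambda) \mathbf{z}_2$, where
$0 \leq \lambda \leq 1$, and then decoding by computing $\mathbb{E}[\mathbf{x}|\mathbf{z}]$. This is called latent space interpolation.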
(The justification for taking a linear interpolation is that the learned manifold has approximately 
zero curvature, as shown in [SKTF18].) A VAE is more useful for latent space interpolation than an 
AE because its latent space is smoother, and because the model can generate from almost any point 
in latent space. This difference can be seen by comparing Figure 20.27a with Figure 20.27b. 
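
The interpolation step itself is simple, as the following sketch shows. It only produces the intermediate latent codes; mapping each code back to an image requires a trained decoder, which is not shown here. The function name and the toy 2d codes are assumptions for illustration.

```python
def interpolate_latents(z1, z2, num_steps):
    """Return codes lam * z1 + (1 - lam) * z2 for lam evenly spaced in [0, 1].

    Requires num_steps >= 2. The first code equals z2 (lam = 0) and the last
    equals z1 (lam = 1).
    """
    codes = []
    for step in range(num_steps):
        lam = step / (num_steps - 1)
        codes.append([lam * a + (1.0 - lam) * b for a, b in zip(z1, z2)])
    return codes

# Hypothetical encodings of two images in a 2d latent space.
z1 = [0.0, 2.0]
z2 = [4.0, -2.0]
codes = interpolate_latents(z1, z2, num_steps=3)
# Each code would then be passed through the decoder to obtain E[x|z].
```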

Figure 20.28: Illustration of the tangent space and tangent vectors at two different points on a 2d curved
manifold. From Figure 1 of [Bro+17a]. Used with kind permission of Michael Bronstein.


20.4 Manifold learning * 


In this section, we discuss the problem of recovering the underlying low-dimensional structure in 
a high-dimensional dataset. This structure is often assumed to be a curved manifold (explained 
in Section 20.4.1), so this problem is called manifold learning or nonlinear dimensionality 
reduction. The key difference from methods such as autoencoders (Section 20.3) is that we will 
focus on non-parametric methods, in which we compute an embedding for each point in the training 
set, as opposed to learning a generic model that can embed any input vector. That is, the methods 
we discuss do not (easily) support out-of-sample generalization. However, they can be easier 
to fit, and are quite flexible. Such methods can be useful for unsupervised learning (knowledge
discovery), data visualization, and as a preprocessing step for supervised learning. See [AAB21] for a 
recent review of this field. 


20.4.1 What are manifolds? 


Roughly speaking, a manifold is a topological space which is locally Euclidean. One of the simplest 
examples is the surface of the earth, which is a curved 2d surface embedded in a 3d space. At each 
local point on the surface, the earth seems flat. 

More formally, a d-dimensional manifold $\mathcal{X}$ is a space in which each point $\mathbf{x} \in \mathcal{X}$ has a neighborhood
which is topologically equivalent to a d-dimensional Euclidean space, called the tangent space,
denoted $\mathcal{T}_{\mathbf{x}} = T_{\mathbf{x}} \mathcal{X}$. This is illustrated in Figure 20.28.

A Riemannian manifold is a differentiable manifold that associates an inner product operator 
at each point x in tangent space; this is assumed to depend smoothly on the position x. The inner 
product induces a notion of distance, angles, and volume. The collection of these inner products is 
called a Riemannian metric. It can be shown that any sufficiently smooth Riemannian manifold 
can be embedded into a Euclidean space of potentially higher dimension; the Riemannian inner 
product at a point then becomes the Euclidean inner product in that tangent space.


20.4.2 The manifold hypothesis 


Most “naturally occurring” high dimensional datasets lie on a low dimensional manifold. This is called the
manifold hypothesis [FMN16]. For example, consider the case of an image. Figure 20.29a shows a 

Figure 20.29: Illustration of the image manifold. (a) An image of the digit 6 from the USPS dataset, of
size 64 x 57 = 3,648. (b) A random sample from the space $\{0, 1\}^{3648}$ reshaped as an image. (c) A dataset
created by rotating the original image by one degree 360 times. We project this data onto its first two principal
components, to reveal the underlying 2d circular manifold. From Figure 1 of [Law12]. Used with kind
permission of Neil Lawrence.


single image of size 64 x 57. This is a vector in a 3,648-dimensional space, where each dimension 
corresponds to a pixel intensity. Suppose we try to generate an image by drawing a random point 
in this space; it is unlikely to look like the image of a digit, as shown in Figure 20.29b. However, 
the pixels are not independent of each other, since they are generated by some lower dimensional 
structure, namely the shape of the digit 6. 

As we vary the shape, we will generate different images. We can often characterize the space of 
shape variations using a low-dimensional manifold. This is illustrated in Figure 20.29c, where we 
apply PCA (Section 20.1) to project a dataset of 360 images, each one a slightly rotated version 
of the digit 6, into a 2d space. We see that most of the variation in the data is captured by an 
underlying curved 2d manifold. We say that the intrinsic dimensionality d of the data is 2, even 
though the ambient dimensionality D is 3,648. 


20.4.3 Approaches to manifold learning 


In the rest of this section, we discuss ways to learn manifolds from data. There are many different 
algorithms that have been proposed, which make different assumptions about the nature of the 
manifold, and which have different computational properties. We discuss a few of these methods in 
the following sections. For more details, see e.g., [Bur10]. 

The methods can be categorized as shown in Table 20.1. The term “nonparametric” refers to 
methods that learn a low dimensional embedding $\mathbf{z}_i$ for each datapoint $\mathbf{x}_i$, but do not learn a mapping
function which can be applied to an out-of-sample datapoint. (However, [Ben+04b] discusses how to 
extend many of these methods beyond the training set by learning a kernel.) 

In the sections below, we compare some of these methods using 2 different datasets: a set of 
1000 3d-points sampled from the 2d “Swiss roll” manifold, and a set of 1797 64-dimensional points
sampled from the UCI digits dataset. See Figure 20.30 for an illustration of the data. We will learn 
a 2d manifold, so we can visualize the data. 

Method                 Parametric   Convex       Section

PCA / classical MDS    N            Y (Dense)    Section 20.1
Kernel PCA             N            Y (Dense)    Section 20.4.6
Isomap                 N            Y (Dense)    Section 20.4.5
LLE                    N            Y (Sparse)   Section 20.4.8
Laplacian Eigenmaps    N            Y (Sparse)   Section 20.4.9
tSNE                   N            N            Section 20.4.10
Autoencoder            Y            N            Section 20.3

Table 20.1: A list of some approaches to dimensionality reduction. If a method is convex, we specify in 
parentheses whether it requires solving a sparse or dense eigenvalue problem. 

Figure 20.30: Illustration of some data generated from low-dimensional manifolds. (a) The 2d Swiss-roll
manifold embedded into 3d. Generated by manifold_swiss_sklearn.ipynb. (b) Sample of some UCI digits,
which have size 8 x 8 = 64. Generated by manifold_digits_sklearn.ipynb.


20.4.4 Multi-dimensional scaling (MDS) 


The simplest approach to manifold learning is multidimensional scaling (MDS). This tries to
find a set of low dimensional vectors $\{\mathbf{z}_i \in \mathbb{R}^L : i = 1 : N\}$ such that the pairwise distances between
these vectors are as similar as possible to a set of pairwise dissimilarities $\mathbf{D} = \{d_{ij}\}$ provided by the
user. There are several variants of MDS, one of which turns out to be equivalent to PCA, as we
discuss below.


20.4.4.1 Classical MDS 


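Suppose we start with an N x D data matrix $\mathbf{X}$ with rows $\mathbf{x}_i$. Let us define the centered Gram (similarity)
matrix as follows:


$$\tilde{K}_{ij} = \langle \mathbf{x}_i - \bar{\mathbf{x}}, \mathbf{x}_j - \bar{\mathbf{x}} \rangle \tag{20.104}$$


In matrix notation, we have $\tilde{\mathbf{K}} = \tilde{\mathbf{X}} \tilde{\mathbf{X}}^\top$, where $\tilde{\mathbf{X}} = \mathbf{C}_N \mathbf{X}$ and $\mathbf{C}_N = \mathbf{I}_N - \frac{1}{N} \mathbf{1}_N \mathbf{1}_N^\top$ is the centering
matrix.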
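
Now define the strain of a set of embeddings as follows:


$$\mathcal{L}_{\mathrm{strain}}(\mathbf{Z}) = \sum_{i,j} \left( \tilde{K}_{ij} - \langle \tilde{\mathbf{z}}_i, \tilde{\mathbf{z}}_j \rangle \right)^2 = \|\tilde{\mathbf{K}} - \tilde{\mathbf{Z}} \tilde{\mathbf{Z}}^\top\|_F^2 \tag{20.105}$$

where $\tilde{\mathbf{z}}_i = \mathbf{z}_i - \bar{\mathbf{z}}$ is the centered embedding vector. Intuitively this measures how well similarities in
the high-dimensional data space, $\tilde{K}_{ij}$, are matched by similarities in the low-dimensional embedding
space, $\langle \tilde{\mathbf{z}}_i, \tilde{\mathbf{z}}_j \rangle$. Minimizing this loss is called classical MDS.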

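We know from Section 7.5 that the best rank L approximation to a matrix is its truncated SVD
representation, $\tilde{\mathbf{K}} = \mathbf{U} \mathbf{S} \mathbf{V}^\top$. Since $\tilde{\mathbf{K}}$ is positive semi definite, we have that $\mathbf{V} = \mathbf{U}$. Hence the
optimal embedding satisfies


$$\tilde{\mathbf{Z}} \tilde{\mathbf{Z}}^\top = \mathbf{U} \mathbf{S} \mathbf{U}^\top = (\mathbf{U} \mathbf{S}^{\frac{1}{2}})(\mathbf{S}^{\frac{1}{2}} \mathbf{U}^\top) \tag{20.106}$$


Thus we can set the embedding vectors to be the rows of $\tilde{\mathbf{Z}} = \mathbf{U} \mathbf{S}^{\frac{1}{2}}$.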

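Now we describe how to apply classical MDS to a dataset where we just have Euclidean distances,
rather than raw features. First we compute a matrix of squared Euclidean distances, $\mathbf{D}^{(2)} = \mathbf{D} \odot \mathbf{D}$,
which has the following entries:


\begin{align}
D^{(2)}_{ij} = \|\mathbf{x}_i - \mathbf{x}_j\|^2 &= \|\mathbf{x}_i - \bar{\mathbf{x}}\|^2 + \|\mathbf{x}_j - \bar{\mathbf{x}}\|^2 - 2 \langle \mathbf{x}_i - \bar{\mathbf{x}}, \mathbf{x}_j - \bar{\mathbf{x}} \rangle \tag{20.107} \\
&= \|\mathbf{x}_i - \bar{\mathbf{x}}\|^2 + \|\mathbf{x}_j - \bar{\mathbf{x}}\|^2 - 2 \tilde{K}_{ij} \tag{20.108}
\end{align}

We see that $\mathbf{D}^{(2)}$ only differs from $\tilde{\mathbf{K}}$ by some row and column constants (and a factor of -2). Hence
we can compute $\tilde{\mathbf{K}}$ by double centering $\mathbf{D}^{(2)}$ using Equation (7.89) to get $\tilde{\mathbf{K}} = -\frac{1}{2} \mathbf{C}_N \mathbf{D}^{(2)} \mathbf{C}_N$. In
other words,


$$\tilde{K}_{ij} = -\frac{1}{2} \left( d_{ij}^2 - \frac{1}{N} \sum_{l=1}^{N} d_{il}^2 - \frac{1}{N} \sum_{l=1}^{N} d_{jl}^2 + \frac{1}{N^2} \sum_{l=1}^{N} \sum_{m=1}^{N} d_{lm}^2 \right) \tag{20.109}$$


We can then compute the embeddings as before.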
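
The double centering step can be checked numerically. The sketch below recovers the centered Gram matrix from squared Euclidean distances alone, and compares it with the Gram matrix computed directly from the (centered) features. The helper names and the four toy points are assumptions for illustration.

```python
def squared_distances(X):
    """Matrix of squared Euclidean distances between the rows of X."""
    return [[sum((a - b) ** 2 for a, b in zip(xi, xj)) for xj in X] for xi in X]

def double_center(D2):
    """Compute -1/2 * C D2 C elementwise, where C = I - (1/N) 1 1^T.

    D2 is assumed symmetric, so its column means equal its row means.
    """
    n = len(D2)
    row_mean = [sum(row) / n for row in D2]
    total_mean = sum(row_mean) / n
    return [[-0.5 * (D2[i][j] - row_mean[i] - row_mean[j] + total_mean)
             for j in range(n)] for i in range(n)]

def centered_gram(X):
    """Gram matrix of the mean-centered rows of X, for comparison."""
    n, d = len(X), len(X[0])
    mean = [sum(x[k] for x in X) / n for k in range(d)]
    Xc = [[x[k] - mean[k] for k in range(d)] for x in X]
    return [[sum(a * b for a, b in zip(xi, xj)) for xj in Xc] for xi in Xc]

X = [[0.0, 0.0], [3.0, 0.0], [0.0, 4.0], [1.0, 1.0]]
K_from_distances = double_center(squared_distances(X))
K_from_features = centered_gram(X)
```

The two matrices agree up to floating point error, which confirms that the distances contain all the information needed to run classical MDS.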
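It turns out that classical MDS is equivalent to PCA (Section 20.1). To see this, let $\tilde{\mathbf{K}} = \mathbf{U}_L \mathbf{S}_L \mathbf{U}_L^\top$
be the rank L truncated SVD of the centered kernel matrix. The MDS embedding is given by
$\mathbf{Z}_{\mathrm{MDS}} = \mathbf{U}_L \mathbf{S}_L^{\frac{1}{2}}$. Now consider the rank L SVD of the centered data matrix, $\tilde{\mathbf{X}} = \mathbf{U}_X \mathbf{S}_X \mathbf{V}_X^\top$. The
PCA embedding is $\mathbf{Z}_{\mathrm{PCA}} = \mathbf{U}_X \mathbf{S}_X$. Now


$$\tilde{\mathbf{K}} = \tilde{\mathbf{X}} \tilde{\mathbf{X}}^\top = \mathbf{U}_X \mathbf{S}_X \mathbf{V}_X^\top \mathbf{V}_X \mathbf{S}_X \mathbf{U}_X^\top = \mathbf{U}_X \mathbf{S}_X^2 \mathbf{U}_X^\top = \mathbf{U}_L \mathbf{S}_L \mathbf{U}_L^\top \tag{20.110}$$

Hence $\mathbf{U}_X = \mathbf{U}_L$ and $\mathbf{S}_X = \mathbf{S}_L^{\frac{1}{2}}$, and so $\mathbf{Z}_{\mathrm{PCA}} = \mathbf{Z}_{\mathrm{MDS}}$.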


20.4.4.2 Metric MDS 


Classical MDS assumes Euclidean distances. We can generalize it to allow for any dissimilarity 
measure by defining the stress function 


$$\mathcal{L}_{\mathrm{stress}}(\mathbf{Z}) = \sqrt{\frac{\sum_{i<j} (d_{ij} - \hat{d}_{ij})^2}{\sum_{ij} d_{ij}^2}} \tag{20.111}$$

Figure 20.31: Metric MDS applied to (a) Swiss roll. Generated by manifold_swiss_sklearn.ipynb. (b) UCI
digits. Generated by manifold_digits_sklearn.ipynb.


where $\hat{d}_{ij} = \|\mathbf{z}_i - \mathbf{z}_j\|$. This is called metric MDS. Note that this is a different objective than the
one used by classical MDS, so even if $d_{ij}$ are Euclidean distances, the results will be different.

We can use gradient descent to solve the optimization problem. However, it is better to use a 
bound optimization algorithm (Section 8.7) called SMACOF [Lee77], which stands for “Scaling
by MAjorizing a COmplicated Function”. (This is the method implemented in scikit-learn.) See
Figure 20.31 for the results of applying this to our running example. 
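
To make the objective concrete, the sketch below evaluates a normalized stress for a candidate embedding: the square root of the summed squared differences between the given dissimilarities and the embedding distances, divided by the summed squared dissimilarities. Both sums here run over pairs i < j, which is an assumption of this sketch; other normalizations change the value only by a constant factor. This evaluates the loss only, and is not the SMACOF optimizer.

```python
import math

def pairwise_distances(points):
    """Matrix of Euclidean distances between all pairs of points."""
    return [[math.dist(a, b) for b in points] for a in points]

def stress(D, Z):
    """Normalized stress of embedding Z with respect to dissimilarities D."""
    D_hat = pairwise_distances(Z)
    pairs = [(i, j) for i in range(len(D)) for j in range(i + 1, len(D))]
    num = sum((D[i][j] - D_hat[i][j]) ** 2 for i, j in pairs)
    den = sum(D[i][j] ** 2 for i, j in pairs)
    return math.sqrt(num / den)

# Four hypothetical points in 3d, and a 2d embedding that drops the third axis.
X = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)]
D = pairwise_distances(X)
Z_flat = [p[:2] for p in X]
stress_flat = stress(D, Z_flat)      # positive: the projection distorts distances
stress_perfect = stress(D, X)        # zero: all distances are reproduced
```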


20.4.4.3  Non-metric MDS 


Instead of trying to match the distance between points, we can instead just try to match the ranking 
of how similar points are. To do this, let f(d) be a monotonic transformation from distances to ranks. 
Now define the loss 


$$\mathcal{L}_{\mathrm{NM}}(\mathbf{Z}) = \sqrt{\frac{\sum_{i<j} (f(d_{ij}) - \hat{d}_{ij})^2}{\sum_{ij} \hat{d}_{ij}^2}} \tag{20.112}$$


where $\hat{d}_{ij} = \|\mathbf{z}_i - \mathbf{z}_j\|$. Minimizing this is known as non-metric MDS.

This objective can be optimized iteratively. First the function f is optimized, for a given $\mathbf{Z}$, using
isotonic regression; this finds the optimal monotonic transformation of the input distances to match
the current embedding distances. Then the embeddings $\mathbf{Z}$ are optimized, for a given f, using gradient
descent, and the process repeats.


20.4.4.4 Sammon mapping 


Metric MDS tries to minimize the sum of squared distances, so it puts the most emphasis on large 
distances. However, for many embedding methods, small distances matter more, since they capture 
local structure. One way to capture this is to divide each term of the loss by $d_{ij}$, so small distances
get upweighted:


$$\mathcal{L}_{\mathrm{sammon}}(\mathbf{Z}) = \left( \frac{1}{\sum_{i<j} d_{ij}} \right) \sum_{i<j} \frac{(\hat{d}_{ij} - d_{ij})^2}{d_{ij}} \tag{20.113}$$

Figure 20.32: (a) If we measure distances along the manifold, we find d(1,6) > d(1,4), whereas if we measure
in ambient space, we find d(1,6) < d(1,4). The plot at the bottom shows the underlying 1d manifold. (b) The
K-nearest neighbors graph for some datapoints; the red path is the shortest distance between A and B on this
graph. From [Hin18]. Used with kind permission of Geoff Hinton.


Minimizing this results in a Sammon mapping. (The coefficient in front of the sum is just to 
simplify the gradient of the loss.) Unfortunately this is a non-convex objective, and it arguably puts 
too much emphasis on getting very small distances exactly right. We will discuss better methods for 
capturing local structure later on. 
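
The effect of the weighting can be seen in a small example. The sketch below computes the Sammon loss for two hypothetical 1d embeddings of three points on a line. Both embeddings have the same total unweighted squared error, but one places its error on the smallest distance, and is therefore penalized more heavily.

```python
import math

def pairwise_distances(points):
    """Matrix of Euclidean distances between all pairs of points."""
    return [[math.dist(a, b) for b in points] for a in points]

def sammon_loss(D, Z):
    """Sammon loss: squared errors weighted by 1 / d_ij, then normalized."""
    D_hat = pairwise_distances(Z)
    pairs = [(i, j) for i in range(len(D)) for j in range(i + 1, len(D))]
    scale = sum(D[i][j] for i, j in pairs)
    return sum((D_hat[i][j] - D[i][j]) ** 2 / D[i][j] for i, j in pairs) / scale

# True configuration: distances are d01 = 1, d02 = 11, d12 = 10.
X = [[0.0], [1.0], [11.0]]
D = pairwise_distances(X)

Z_small_err = [[0.0], [1.5], [11.5]]   # errors of 0.5 on d01 and d02
Z_large_err = [[0.0], [1.0], [11.5]]   # errors of 0.5 on d02 and d12

loss_small = sammon_loss(D, Z_small_err)
loss_large = sammon_loss(D, Z_large_err)
# loss_small > loss_large, since the error on the small distance d01 is upweighted.
```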


20.4.5 Isomap 


If the high-dimensional data lies on or near a curved manifold, such as the Swiss roll example, then 
MDS might consider two points to be close even if their distance along the manifold is large. This is 
illustrated in Figure 20.32a. 

One way to capture this is to create the K-nearest neighbor graph between datapoints⁵, and then
approximate the manifold distance between a pair of points by the shortest distance along this graph;
this can be computed efficiently using Dijkstra's shortest path algorithm. See Figure 20.32b for an
illustration. Once we have computed this new distance metric, we can apply classical MDS (i.e.,
PCA). This is a way to capture local structure while avoiding local optima. The overall method is
called isomap [TSL00].
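
The first two steps of this procedure (building the neighbor graph and computing shortest paths) can be sketched in pure Python as follows. This is an illustration on a toy half circle, not the scikit-learn implementation; the function names are assumptions, and the final classical MDS step is omitted.

```python
import heapq
import math

def knn_graph(X, k):
    """Symmetric K-nearest-neighbor graph, stored as {i: {j: edge_length}}."""
    n = len(X)
    graph = {i: {} for i in range(n)}
    for i in range(n):
        by_dist = sorted((math.dist(X[i], X[j]), j) for j in range(n) if j != i)
        for d, j in by_dist[:k]:
            graph[i][j] = d
            graph[j][i] = d   # symmetrize so edges can be walked both ways
    return graph

def dijkstra(graph, source):
    """Shortest path lengths from source to every node of the graph."""
    dist = {node: math.inf for node in graph}
    dist[source] = 0.0
    heap = [(0.0, source)]
    while heap:
        d, u = heapq.heappop(heap)
        if d > dist[u]:
            continue          # stale heap entry
        for v, w in graph[u].items():
            if d + w < dist[v]:
                dist[v] = d + w
                heapq.heappush(heap, (d + w, v))
    return dist

# 21 points on a unit half circle, a 1d manifold embedded in 2d.
X = [(math.cos(math.pi * i / 20), math.sin(math.pi * i / 20)) for i in range(21)]
geodesic = dijkstra(knn_graph(X, k=2), source=0)
euclidean = math.dist(X[0], X[20])
# geodesic[20] is close to the arc length pi, whereas euclidean is about 2.
```

Running `dijkstra` from every node yields the full matrix of approximate manifold distances, which would then be passed to classical MDS.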

See Figure 20.33 for the results of this method on our running example. We see that they are 
quite reasonable. However, if the data is noisy, there can be “false” edges in the nearest neighbor 
graph, which can result in “short circuits” which significantly distort the embedding, as shown in 
Figure 20.34. This problem is known as “topological instability” [BS02]. Choosing a very small 
neighborhood does not solve this problem, since this can fragment the manifold into a large number 
of disconnected regions. Various other solutions have been proposed, e.g., [CC07]. 


20.4.6 Kernel PCA 


PCA (and classical MDS) finds the best linear projection of the data, so as to preserve pairwise 
similarities between all the points. In this section, we consider nonlinear projections. The key idea 


5. In scikit-learn, you can use the function sklearn.neighbors.kneighbors_graph. 


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 




Figure 20.33: Isomap applied to (a) Swiss roll. Generated by manifold_swiss_sklearn.ipynb. (b) UCI digits.
Generated by manifold_digits_sklearn.ipynb.

Figure 20.34: (a) Noisy version of Swiss roll data. We perturb each point by adding $\mathcal{N}(0, 0.5^2)$ noise. (b)
Results of Isomap applied to this data. Generated by manifold_swiss_sklearn.ipynb.


is to solve PCA by finding the eigenvectors of the inner product (Gram) matrix $\mathbf{K} = \mathbf{X}\mathbf{X}^\top$, as in
Section 20.1.3.2, and then to use the kernel trick (Section 17.3.4), which lets us replace inner products
such as $\mathbf{x}_i^\top \mathbf{x}_j$ with a kernel function, $K_{ij} = \mathcal{K}(\mathbf{x}_i, \mathbf{x}_j)$. This is known as kernel PCA [SSM98].

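Recall from Mercer's theorem that the use of a kernel implies some underlying feature space, so we
are implicitly replacing $\mathbf{x}_i$ with $\boldsymbol{\phi}(\mathbf{x}_i) = \boldsymbol{\phi}_i$. Let $\boldsymbol{\Phi}$ be the corresponding (notional) design matrix,
and $\mathbf{K} = \boldsymbol{\Phi}\boldsymbol{\Phi}^\top$ be the Gram matrix. Finally, let $\mathbf{S}_{\phi} = \frac{1}{N} \sum_i \boldsymbol{\phi}_i \boldsymbol{\phi}_i^\top$ be the covariance matrix in feature
space. (We are assuming for now the features are centered.) From Equation (20.22), the normalized
eigenvectors of $\mathbf{S}_{\phi}$ are given by $\mathbf{V}_{\mathrm{kPCA}} = \boldsymbol{\Phi}^\top \mathbf{U} \boldsymbol{\Lambda}^{-\frac{1}{2}}$, where $\mathbf{U}$ and $\boldsymbol{\Lambda}$ contain the eigenvectors and
eigenvalues of $\mathbf{K}$. Of course, we can't actually compute $\mathbf{V}_{\mathrm{kPCA}}$, since $\boldsymbol{\phi}_i$ is potentially infinite
dimensional. However, we can compute the projection of a test vector $\mathbf{x}_*$ onto the feature space as
follows:

$$\boldsymbol{\phi}_*^\top \mathbf{V}_{\mathrm{kPCA}} = \boldsymbol{\phi}_*^\top \boldsymbol{\Phi}^\top \mathbf{U} \boldsymbol{\Lambda}^{-\frac{1}{2}} = \mathbf{k}_*^\top \mathbf{U} \boldsymbol{\Lambda}^{-\frac{1}{2}} \tag{20.114}$$

where $\mathbf{k}_* = [\mathcal{K}(\mathbf{x}_*, \mathbf{x}_1), \ldots, \mathcal{K}(\mathbf{x}_*, \mathbf{x}_N)]$.
There is one final detail to worry about. The covariance matrix is only given by $\mathbf{S}_{\phi} = \frac{1}{N} \boldsymbol{\Phi}^\top \boldsymbol{\Phi}$ if the
features are zero-mean. Thus we can only use the Gram matrix $\mathbf{K} = \boldsymbol{\Phi}\boldsymbol{\Phi}^\top$ if $\mathbb{E}[\boldsymbol{\phi}_i] = \mathbf{0}$. Unfortunately,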

Figure 20.35: Visualization of the first 8 kernel principal component basis functions derived from some 2d
data. We use an RBF kernel with $\sigma^2 = 0.1$. Generated by kpcaScholkopf.ipynb.


we cannot simply subtract off the mean in feature space, since it may be infinite dimensional. However,
there is a trick we can use. Define the centered feature vector as $\tilde{\boldsymbol{\phi}}_i = \boldsymbol{\phi}(\mathbf{x}_i) - \frac{1}{N} \sum_{j=1}^{N} \boldsymbol{\phi}(\mathbf{x}_j)$. The
Gram matrix of the centered feature vectors is given by $\tilde{K}_{ij} = \tilde{\boldsymbol{\phi}}_i^\top \tilde{\boldsymbol{\phi}}_j$. Using the double centering trick
from Equation (7.89), we can write this in matrix form as $\tilde{\mathbf{K}} = \mathbf{C}_N \mathbf{K} \mathbf{C}_N$, where $\mathbf{C}_N \triangleq \mathbf{I}_N - \frac{1}{N} \mathbf{1}_N \mathbf{1}_N^\top$
is the centering matrix.
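
The centering trick only needs the kernel matrix itself, as the following sketch shows. We assume an RBF kernel of the form exp(-||x - x'||^2 / (2 sigma^2)); the function names and the three toy points are hypothetical. The eigendecomposition of the centered matrix, which completes kPCA, is omitted.

```python
import math

def rbf_kernel_matrix(X, sigma2):
    """Gram matrix with entries exp(-||x_i - x_j||^2 / (2 * sigma2))."""
    return [[math.exp(-sum((a - b) ** 2 for a, b in zip(xi, xj)) / (2.0 * sigma2))
             for xj in X] for xi in X]

def center_kernel(K):
    """Compute C K C elementwise, where C = I - (1/N) 1 1^T.

    Entry (i, j) is K[i][j] minus the mean of row i, minus the mean of
    column j, plus the mean of all entries.
    """
    n = len(K)
    row_mean = [sum(row) / n for row in K]
    col_mean = [sum(K[i][j] for i in range(n)) / n for j in range(n)]
    total_mean = sum(row_mean) / n
    return [[K[i][j] - row_mean[i] - col_mean[j] + total_mean
             for j in range(n)] for i in range(n)]

X = [[1.0, 2.0], [3.0, 4.0], [5.0, 9.0]]
K = rbf_kernel_matrix(X, sigma2=4.0)
K_centered = center_kernel(K)
# Every row and column of K_centered sums to (numerically) zero, as expected
# for the Gram matrix of zero-mean feature vectors.
```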

If we apply kPCA with a linear kernel, we recover regular PCA (classical MDS). This is limited 
to using L < D embedding dimensions. If we use a non-degenerate kernel, we can use up to N 
components, since the size of ® is N x D*, where D* is the (potentially infinite) dimensionality of 
embedded feature vectors. Figure 20.35 gives an example of the method applied to some D = 2 
dimensional data using an RBF kernel. We project points in the unit grid onto the first 8 components 
and visualize the corresponding surfaces using a contour plot. We see that the first two components 
separate the three clusters, and the following components split the clusters. 

See Figure 20.36 for the results of kPCA (with an RBF kernel) on our running example. In
this case, the results are arguably not very useful. In fact, it can be shown that kPCA with an RBF 
kernel expands the feature space instead of reducing it [WSS04], as we saw in Figure 20.35, which 
makes it not very useful as a method for dimensionality reduction. We discuss a solution to this in 
Section 20.4.7. 


20.4.7 Maximum variance unfolding (MVU) 


kPCA with certain kernels, such as RBF, might not result in a low dimensional embedding, as 
discussed in Section 20.4.6. This observation led to the development of the semidefinite embedding 

Figure 20.36: Kernel PCA applied to (a) Swiss roll. Generated by manifold_swiss_sklearn.ipynb. (b) UCI
digits. Generated by manifold_digits_sklearn.ipynb.


algorithm [WSS04], also called maximum variance unfolding, which tries to learn an embedding
$\{\mathbf{z}_i\}$ such that


$$\max \sum_{i,j} \|\mathbf{z}_i - \mathbf{z}_j\|_2^2 \quad \text{s.t.} \quad \|\mathbf{z}_i - \mathbf{z}_j\|_2^2 = \|\mathbf{x}_i - \mathbf{x}_j\|_2^2 \ \text{for all } (i,j) \in G \tag{20.115}$$

where G is the nearest neighbor graph (as in Isomap). This approach explicitly tries to “unfold” the
data manifold while respecting the nearest neighbor constraints.


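This can be reformulated as a semidefinite programming (SDP) problem by defining the kernel
matrix $\mathbf{K} = \mathbf{Z}\mathbf{Z}^\top$ and then optimizing


$$\max \mathrm{tr}(\mathbf{K}) \quad \text{s.t.} \quad \|\mathbf{z}_i - \mathbf{z}_j\|_2^2 = \|\mathbf{x}_i - \mathbf{x}_j\|_2^2, \quad \sum_{ij} K_{ij} = 0, \quad \mathbf{K} \succeq 0 \tag{20.116}$$
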
The resulting kernel is then passed to kPCA, and the resulting eigenvectors give the low dimensional 
embedding. 


20.4.8 Local linear embedding (LLE) 


The techniques we have discussed so far all rely on an eigendecomposition of a full matrix of pairwise 
similarities, either in the ambient space (PCA), in feature space (kPCA), or along the KNN graph 
(Isomap). In this section, we discuss local linear embedding (LLE) [RS00], a technique that solves 
a sparse eigenproblem, thus focusing more on local structure in the data. 

LLE assumes the data manifold around each point $\mathbf{x}_i$ is locally linear. The best linear approximation
can be found by predicting $\mathbf{x}_i$ as a linear combination of its K nearest neighbors using reconstruction
weights $\mathbf{w}_i$. This can be found by solving


\begin{align}
\hat{\mathbf{W}} &= \operatorname*{argmin}_{\mathbf{W}} \sum_{i=1}^{N} \Big\| \mathbf{x}_i - \sum_{j=1}^{N} w_{ij} \mathbf{x}_j \Big\|_2^2 \tag{20.117} \\
\text{subject to} \quad & \begin{cases} w_{ij} = 0 & \text{if } \mathbf{x}_j \notin \mathrm{nbr}(\mathbf{x}_i, K) \\ \sum_{j=1}^{N} w_{ij} = 1 & \text{for } i = 1 : N \end{cases} \tag{20.118}
\end{align}

Figure 20.37: LLE applied to (a) Swiss roll. Generated by manifold_swiss_sklearn.ipynb. (b) UCI digits.
Generated by manifold_digits_sklearn.ipynb.


Note that we need the sum-to-one constraint on the weights to prevent the trivial solution $\mathbf{W} = \mathbf{0}$.
The resulting vector of weights $\mathbf{w}_{i,:}$ constitutes the barycentric coordinates of $\mathbf{x}_i$.

Any linear mapping of this hyperplane to a lower dimensional space preserves the reconstruction 
weights, and thus the local geometry. Thus we can solve for the low-dimensional embeddings for 
each point by solving 


$$
\hat{Z} = \operatorname*{argmin}_{Z} \sum_{i=1}^{N} \Big\| z_i - \sum_{j=1}^{N} \hat{w}_{ij} z_j \Big\|_2^2 \tag{20.119}
$$


where $\hat{w}_{ij} = 0$ if $j$ is not one of the $K$ nearest neighbors of $i$. We can rewrite this loss as

$$
\mathcal{L}(Z) = \|Z - WZ\|_F^2 = \operatorname{tr}\left( Z^T (I - W)^T (I - W) Z \right) \tag{20.120}
$$

Thus the solution is given by the eigenvectors of $(I - W)^T (I - W)$ corresponding to the smallest
nonzero eigenvalues, as shown in Section 7.4.8.

See Figure 20.37 for the results of applying LLE to our running example. In this case, the results do
not seem as good as those produced by Isomap. However, the method tends to be somewhat less 
sensitive to short-circuiting (noise). 
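
The two steps above (solve for the reconstruction weights, then for the embedding) can be sketched in NumPy. This is a minimal dense illustration, not the scikit-learn code used to generate Figure 20.37. The function name `lle`, the regularization constant `reg` (added here because the local Gram matrix is singular whenever K exceeds the ambient dimension), and the toy spiral data are all assumptions made for the example.

```python
import numpy as np

def lle(X, n_neighbors=5, n_components=2, reg=1e-3):
    """Dense LLE sketch. X has shape (N, D); returns the embedding Z and weights W."""
    N = X.shape[0]
    # Pairwise squared distances; a point is never its own neighbor.
    sq = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    np.fill_diagonal(sq, np.inf)
    nbrs = np.argsort(sq, axis=1)[:, :n_neighbors]

    W = np.zeros((N, N))
    for i in range(N):
        G = X[nbrs[i]] - X[i]                  # neighbors centered on x_i, shape (K, D)
        C = G @ G.T                            # local Gram matrix, shape (K, K)
        C += reg * (np.trace(C) + 1e-12) * np.eye(n_neighbors)  # assumed regularizer
        w = np.linalg.solve(C, np.ones(n_neighbors))
        W[i, nbrs[i]] = w / w.sum()            # enforce the sum-to-one constraint

    I_minus_W = np.eye(N) - W
    M = I_minus_W.T @ I_minus_W
    evals, evecs = np.linalg.eigh(M)           # eigenvalues in ascending order
    # Skip the first eigenvector (eigenvalue ~ 0, constant when that eigenvalue is simple).
    Z = evecs[:, 1:n_components + 1]
    return Z, W

# Toy data: a noisy-height spiral in 3d (made up for illustration).
rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0, 3 * np.pi, size=200))
X = np.stack([t * np.cos(t), t * np.sin(t), rng.uniform(0, 1, size=200)], axis=1)
Z, W = lle(X, n_neighbors=8, n_components=2)
```

Because the weight matrix has only K nonzeros per row, a practical implementation would store W as a sparse matrix and use a sparse eigensolver; the dense version here only keeps the example short.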


20.4.9 Laplacian eigenmaps 


In this section, we describe Laplacian eigenmaps or spectral embedding [BN01]. The idea is
to compute a low-dimensional representation of the data in which the weighted distances between 
a datapoint and its K nearest neighbors are minimized. We put more weight on the first nearest 
neighbor than the second, etc. We give the details below. 

20.4.9.1 Using eigenvectors of the graph Laplacian to compute embeddings 

We want to find embeddings which minimize 


$$
\mathcal{L}(Z) = \sum_{(i,j) \in E} W_{ij} \|z_i - z_j\|_2^2 \tag{20.121}
$$

Figure 20.38: Laplacian eigenmaps applied to (a) Swiss roll. Generated by manifold_swiss_sklearn.ipynb.
(b) UCI digits. Generated by manifold_digits_sklearn.ipynb.


where $W_{ij} = \exp(-\frac{1}{2\sigma^2} \|x_i - x_j\|_2^2)$ if $i$ and $j$ are neighbors in the KNN graph and 0 otherwise. We
add the constraint $Z^T D Z = I$ to avoid the degenerate solution where $Z = 0$, where $D$ is the diagonal
weight matrix storing the degree of each node, $D_{ii} = \sum_j W_{ij}$.

We can rewrite the above objective as follows: 


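$$
\begin{aligned}
\mathcal{L}(Z) &= \sum_{ij} W_{ij} \left( \|z_i\|^2 + \|z_j\|^2 - 2 z_i^T z_j \right) && (20.122) \\
&= \sum_i D_{ii} \|z_i\|^2 + \sum_j D_{jj} \|z_j\|^2 - 2 \sum_{ij} W_{ij} z_i^T z_j && (20.123) \\
&= 2 \operatorname{tr}(Z^T D Z) - 2 \operatorname{tr}(Z^T W Z) = 2 \operatorname{tr}(Z^T L Z) && (20.124)
\end{aligned}
$$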


where L = D — W is the graph Laplacian (see Section 20.4.9.2). One can show that minimizing this 
is equivalent to solving the (generalized) eigenvalue problem $L z_i = \lambda_i D z_i$ for the $L$ smallest nonzero
eigenvalues. 

See Figure 20.38 for the results of applying this method (with an RBF kernel) to our running 
example. 
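
As a rough sketch of the procedure, the generalized eigenproblem can be solved with plain NumPy by a change of variables: if $u$ is an eigenvector of the symmetric matrix $D^{-1/2} L D^{-1/2}$, then $z = D^{-1/2} u$ satisfies $Lz = \lambda D z$. The function name, the symmetrization of the KNN graph, the bandwidth `sigma`, and the toy circle data below are assumptions for illustration; the figure in the text was produced with scikit-learn.

```python
import numpy as np

def laplacian_eigenmaps(X, n_neighbors=10, n_components=2, sigma=1.0):
    """Returns embedding Z, weight matrix W, and the eigenvalues that were used."""
    N = X.shape[0]
    sq = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    sq_nbr = sq.copy()
    np.fill_diagonal(sq_nbr, np.inf)           # exclude self-neighbors
    order = np.argsort(sq_nbr, axis=1)[:, :n_neighbors]

    # Symmetrized KNN graph (assumed): keep (i, j) if either lists the other.
    A = np.zeros((N, N), dtype=bool)
    A[np.repeat(np.arange(N), n_neighbors), order.ravel()] = True
    A = A | A.T
    W = np.where(A, np.exp(-sq / (2 * sigma ** 2)), 0.0)

    d = W.sum(axis=1)
    L = np.diag(d) - W
    d_isqrt = 1.0 / np.sqrt(d)
    L_sym = d_isqrt[:, None] * L * d_isqrt[None, :]   # D^{-1/2} L D^{-1/2}
    evals, U = np.linalg.eigh(L_sym)
    # Drop the trivial eigenvalue 0; this assumes the graph is connected,
    # so that 0 has multiplicity one.
    Z = d_isqrt[:, None] * U[:, 1:n_components + 1]
    return Z, W, evals[1:n_components + 1]

# Toy data: a slightly noisy circle (made up for illustration).
rng = np.random.default_rng(0)
theta = np.linspace(0, 2 * np.pi, 100, endpoint=False)
X = np.stack([np.cos(theta), np.sin(theta)], axis=1) + 0.01 * rng.normal(size=(100, 2))
Z, W, lam = laplacian_eigenmaps(X, n_neighbors=6, n_components=2, sigma=0.5)
```

The returned embedding satisfies the constraint $Z^T D Z = I$ by construction, since the columns of `U` are orthonormal.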


20.4.9.2 What is the graph Laplacian? 


We saw above that we can compute the eigenvectors of the graph Laplacian in order to learn a good 
embedding of the high dimensional points. In this section, we give some intuition as to why this 
works. 

Let $W$ be a symmetric weight matrix for a graph, where $W_{ij} = W_{ji} \geq 0$. Let $D = \mathrm{diag}(d_i)$ be a
diagonal matrix containing the weighted degree of each node, $d_i = \sum_j w_{ij}$. We define the graph
Laplacian as follows:

$$
L \triangleq D - W \tag{20.125}
$$

Figure 20.39: Illustration of the Laplacian matrix derived from an undirected graph. From
https://en.wikipedia.org/wiki/Laplacian_matrix. Used with kind permission of Wikipedia author AzaToth.




Figure 20.40: Illustration of a (positive) function defined on a graph. From Figure 1 of [Shu+13]. Used with 
kind permission of Pascal Frossard. 


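Thus the elements of $L$ are given by

$$
L_{ij} = \begin{cases} d_i & \text{if } i = j \\ -w_{ij} & \text{if } i \neq j \text{ and } w_{ij} \neq 0 \\ 0 & \text{otherwise} \end{cases} \tag{20.126}
$$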


See Figure 20.39 for an example of how to compute this. 

Suppose we associate a value $f_i \in \mathbb{R}$ with each node $i$ in the graph (see Figure 20.40 for example).
Then we can use the graph Laplacian as a difference operator, to compute a discrete derivative of
the function at a point:

$$
(Lf)(i) = \sum_{j \in \mathrm{nbr}_i} W_{ij} [f(i) - f(j)] \tag{20.127}
$$

where $\mathrm{nbr}_i$ is the set of neighbors of node $i$. We can also compute an overall measure of “smoothness”
of the function $f$ by computing its Dirichlet energy as follows:

$$
\begin{aligned}
f^T L f &= f^T D f - f^T W f = \sum_i d_i f_i^2 - \sum_{i,j} f_i f_j w_{ij} && (20.128) \\
&= \frac{1}{2} \Big( \sum_i d_i f_i^2 - 2 \sum_{i,j} f_i f_j w_{ij} + \sum_j d_j f_j^2 \Big) = \frac{1}{2} \sum_{i,j} w_{ij} (f_i - f_j)^2 && (20.129)
\end{aligned}
$$


By studying the eigenvalues and eigenvectors of the Laplacian matrix, we can determine various 
useful properties of the function. (Applying linear algebra to study the adjacency matrix of a 
graph, or related matrices, is called spectral graph theory [Chu97].) For example, we see that 
$L$ is symmetric and positive semi-definite, since we have $f^T L f \geq 0$ for all $f \in \mathbb{R}^N$, which follows
from Equation (20.129) due to the assumption that $w_{ij} \geq 0$. Consequently $L$ has $N$ non-negative,
real-valued eigenvalues, $0 \leq \lambda_1 \leq \lambda_2 \leq \cdots \leq \lambda_N$. The corresponding eigenvectors form an orthogonal
basis for the function $f$ defined on the graph, in order of decreasing smoothness.
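
These properties are easy to check numerically on a small graph. The sketch below builds the Laplacian for a made-up 4-node weighted graph (the weights and the function `f` are arbitrary choices for illustration), and computes the difference operator and the Dirichlet energy in two ways each.

```python
import numpy as np

# Symmetric, non-negative edge weights for a small 4-node graph (made up).
W = np.array([[0., 1., 0., 2.],
              [1., 0., 3., 0.],
              [0., 3., 0., 1.],
              [2., 0., 1., 0.]])
d = W.sum(axis=1)                   # weighted degrees
L = np.diag(d) - W                  # graph Laplacian L = D - W

f = np.array([1.0, -2.0, 0.5, 3.0])  # an arbitrary function on the nodes

# Difference operator: (Lf)(i) = sum_j W_ij (f_i - f_j).
Lf_direct = np.array([np.sum(W[i] * (f[i] - f)) for i in range(len(f))])

# Dirichlet energy as a quadratic form and as a sum over pairs.
energy_quadratic = f @ L @ f
energy_pairwise = 0.5 * np.sum(W * (f[:, None] - f[None, :]) ** 2)

# Spectrum: real and non-negative, with a constant eigenvector at eigenvalue 0.
evals, evecs = np.linalg.eigh(L)
```

Since `np.linalg.eigh` returns eigenvalues in ascending order, the columns of `evecs` are ordered from the smoothest function on the graph (the constant one) to the least smooth.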

In Section 20.4.9.1, we discussed Laplacian eigenmaps, which is a way to learn low dimensional
embeddings for high dimensional data vectors. The approach is to let $z_{id} = f_i^d$ be the $d$'th embedding
dimension for input i, and then to find a basis for these functions (i.e., embedding of the points) that 
varies smoothly over the graph, thus respecting distance of the points in ambient space. 

There are many other applications of the graph Laplacian in ML. For example, in Section 21.5.1, we 
discuss normalized cuts, which is a way to learn a clustering of high dimensional data vectors based 
on pairwise similarity; and [WTN19] discusses how to use the eigenvectors of the state transition 
matrix to learn representations for RL. 


20.4.10 t-SNE 


In this section, we describe a very popular nonconvex technique for learning low dimensional 
embeddings called t-SNE [MH08]. This extends the earlier stochastic neighbor embedding 
method of [HR03], so we first describe SNE, before describing the t-SNE extension. 


20.4.10.1 Stochastic neighborhood embedding (SNE) 


The basic idea in SNE is to convert high-dimensional Euclidean distances into conditional probabilities 
that represent similarities. More precisely, we define $p_{j|i}$ to be the probability that point $i$ would pick
point $j$ as its neighbor if neighbors were picked in proportion to their probability under a Gaussian
centered at $x_i$:

$$
p_{j|i} = \frac{\exp(-\frac{1}{2\sigma_i^2} \|x_i - x_j\|^2)}{\sum_{k \neq i} \exp(-\frac{1}{2\sigma_i^2} \|x_i - x_k\|^2)} \tag{20.130}
$$


Here $\sigma_i^2$ is the variance for data point $i$, which can be used to “magnify” the scale of points in dense
regions of input space, and diminish the scale in sparser regions. (We discuss how to estimate the
length scales $\sigma_i^2$ shortly.)

Let $z_i$ be the low dimensional embedding representing $x_i$. We define similarities in the low
dimensional space in an analogous way:

$$
q_{j|i} = \frac{\exp(-\|z_i - z_j\|^2)}{\sum_{k \neq i} \exp(-\|z_i - z_k\|^2)} \tag{20.131}
$$


In this case, the variance is fixed to a constant; changing it would just rescale the learned map, and 
not change its topology. 

If the embedding is a good one, then $q_{j|i}$ should match $p_{j|i}$. Therefore, SNE defines the objective
to be 


$$
\mathcal{L} = \sum_i D_{\mathrm{KL}}(P_i \,\|\, Q_i) = \sum_i \sum_j p_{j|i} \log \frac{p_{j|i}}{q_{j|i}} \tag{20.132}
$$

where $P_i$ is the conditional distribution over all other data points given $x_i$, $Q_i$ is the conditional
distribution over all other latent points given $z_i$, and $D_{\mathrm{KL}}(P_i \,\|\, Q_i)$ is the KL divergence (Section 6.2)
between the distributions.

Note that this is an asymmetric objective. In particular, there is a large cost if a small $q_{j|i}$ is used
to model a large $p_{j|i}$. This objective will prefer to pull distant points together rather than push
nearby points apart. We can get a better idea of the geometry by looking at the gradient for each
embedding vector, which is given by

$$
\nabla_{z_i} \mathcal{L}(Z) = 2 \sum_j (z_i - z_j)(p_{j|i} - q_{j|i} + p_{i|j} - q_{i|j}) \tag{20.133}
$$


Thus points are pulled towards each other if the p’s are bigger than the q’s, and repelled if the q’s 
are bigger than the p’s. 
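
One way to build confidence in a gradient formula like this is to compare it with finite differences. The sketch below does so for the SNE objective, using the analytic form $2 \sum_j (z_i - z_j)(p_{j|i} - q_{j|i} + p_{i|j} - q_{i|j})$. The helper names, the choice $\sigma_i = 1$ for every point, and the random toy data are assumptions made for the example.

```python
import numpy as np

def conditional_probs(Y, inv_two_sigma_sq):
    """Row i holds p_{j|i} (or q_{j|i}). inv_two_sigma_sq is 1/(2 sigma_i^2),
    given either as a scalar or as an array of shape (N,)."""
    sq = np.sum((Y[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    logits = -sq * np.reshape(inv_two_sigma_sq, (-1, 1))
    np.fill_diagonal(logits, -np.inf)          # a point never picks itself
    logits -= logits.max(axis=1, keepdims=True)
    P = np.exp(logits)
    return P / P.sum(axis=1, keepdims=True)

def sne_loss(P, Z):
    Q = conditional_probs(Z, 1.0)              # fixed variance in latent space
    mask = ~np.eye(len(Z), dtype=bool)
    return np.sum(P[mask] * np.log(P[mask] / Q[mask]))

def sne_grad(P, Z):
    Q = conditional_probs(Z, 1.0)
    coef = (P - Q) + (P - Q).T                 # p_{j|i} - q_{j|i} + p_{i|j} - q_{i|j}
    diff = Z[:, None, :] - Z[None, :, :]       # diff[i, j] = z_i - z_j
    return 2.0 * np.sum(coef[:, :, None] * diff, axis=1)

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 5))                    # toy high-dimensional points
Z = rng.normal(size=(6, 2))                    # toy embedding
P = conditional_probs(X, 0.5)                  # sigma_i = 1, so 1/(2 sigma_i^2) = 0.5

# Central finite differences, one coordinate at a time.
eps = 1e-6
num_grad = np.zeros_like(Z)
for idx in np.ndindex(*Z.shape):
    Zp, Zm = Z.copy(), Z.copy()
    Zp[idx] += eps
    Zm[idx] -= eps
    num_grad[idx] = (sne_loss(P, Zp) - sne_loss(P, Zm)) / (2 * eps)
max_err = np.max(np.abs(num_grad - sne_grad(P, Z)))
```

A gradient descent step moves $z_i$ along the negative gradient, i.e., towards $z_j$ when the $p$ terms exceed the $q$ terms, which matches the attraction and repulsion described above.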

Although this is an intuitively sensible objective, it is not convex. Nevertheless it can be minimized 
using SGD. In practice, it helps to add Gaussian noise to the embedding points, and to gradually 
anneal the amount of noise. [Hin13] recommends to “spend a long time at the noise level at which 
the global structure starts to form from the hot plasma of map points” before reducing it.⁶


20.4.10.2 Symmetric SNE 


There is a slightly simpler version of SNE that minimizes a single KL between the joint distribution 
P in high dimensional space and Q in low dimensional space: 


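$$
\mathcal{L} = D_{\mathrm{KL}}(P \,\|\, Q) = \sum_{i<j} p_{ij} \log \frac{p_{ij}}{q_{ij}} \tag{20.134}
$$

This is called symmetric SNE.

The obvious way to define $p_{ij}$ is to use

$$
p_{ij} = \frac{\exp(-\frac{1}{2\sigma^2} \|x_i - x_j\|^2)}{\sum_{k<l} \exp(-\frac{1}{2\sigma^2} \|x_k - x_l\|^2)} \tag{20.135}
$$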


6. See [Ros98; WF20] for a discussion of annealing and phase transitions in unsupervised learning. See also [CP10] 
for a discussion of the elastic embedding algorithm, which uses a homotopy method to more efficiently optimize a 
model that is related to both SNE and Laplacian eigenmaps. 
We can define $q_{ij}$ similarly.
The corresponding gradient becomes

$$
\nabla_{z_i} \mathcal{L}(Z) = 2 \sum_j (z_i - z_j)(p_{ij} - q_{ij}) \tag{20.136}
$$


As before, points are pulled towards each other if the p’s are bigger than the q’s, and repelled if the 
q’s are bigger than the p’s. 

Although symmetric SNE is slightly easier to implement, it loses the nice property of regular SNE 
that the data is its own optimal embedding if the embedding dimension L is set equal to the ambient 
dimension D. Nevertheless, the methods seem to give similar results in practice on real datasets
where L < D. 


20.4.10.3 t-distributed SNE 


A fundamental problem with SNE and many other embedding techniques is that they tend to squeeze 
points that are relatively far away in the high dimensional space close together in the low dimensional 
(usually 2d) embedding space; this is called the crowding problem, and arises due to the use of 
squared errors (or Gaussian probabilities). 

One solution to this is to use a probability distribution in latent space that has heavier tails, 
which eliminates the unwanted attractive forces between points that are relatively far in the high 
dimensional space. An obvious choice is the Student-t distribution (Section 2.7.1). In t-SNE, they 
set the degree of freedom parameter to v = 1, so the distribution becomes equivalent to a Cauchy: 


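$$
q_{ij} = \frac{(1 + \|z_i - z_j\|^2)^{-1}}{\sum_{k<l} (1 + \|z_k - z_l\|^2)^{-1}} \tag{20.137}
$$

We can use the same global KL objective as in Equation (20.134). For t-SNE, the gradient turns
out to be

$$
\nabla_{z_i} \mathcal{L} = 4 \sum_j (p_{ij} - q_{ij})(z_i - z_j)(1 + \|z_i - z_j\|^2)^{-1} \tag{20.138}
$$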


The gradient for symmetric (Gaussian) SNE is the same, but lacks the $(1 + \|z_i - z_j\|^2)^{-1}$ term. This
term is useful because $(1 + \|z_i - z_j\|^2)^{-1}$ acts like an inverse square law. This means that points in
embedding space act like stars and galaxies, forming many well-separated clusters (galaxies) each of 
which has many stars tightly packed inside. This can be useful for separating different classes of data 
in an unsupervised way (see Figure 20.41 for an example). 
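
The t-SNE gradient can be checked numerically in the same spirit. The sketch below assumes the normalization convention of the original t-SNE paper: $P$ and $Q$ are symmetric, have zero diagonal, each sum to one over all ordered pairs $i \neq j$, and the KL sum runs over those pairs. Under that convention the constant in front of the gradient is 4; other conventions change only this constant. The helper names and random toy inputs are made up for illustration.

```python
import numpy as np

def student_kernel(Z):
    """Unnormalized Cauchy similarities (1 + ||z_i - z_j||^2)^{-1}, zero on the diagonal."""
    sq = np.sum((Z[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
    kern = 1.0 / (1.0 + sq)
    np.fill_diagonal(kern, 0.0)
    return kern

def tsne_loss(P, Z):
    kern = student_kernel(Z)
    Q = kern / kern.sum()                      # normalized over all pairs i != j
    mask = P > 0
    return np.sum(P[mask] * np.log(P[mask] / Q[mask]))

def tsne_grad(P, Z):
    kern = student_kernel(Z)
    Q = kern / kern.sum()
    diff = Z[:, None, :] - Z[None, :, :]       # diff[i, j] = z_i - z_j
    return 4.0 * np.sum(((P - Q) * kern)[:, :, None] * diff, axis=1)

rng = np.random.default_rng(1)
A = rng.random((6, 6))
P = A + A.T                                    # symmetric toy affinities
np.fill_diagonal(P, 0.0)
P /= P.sum()
Z = rng.normal(size=(6, 2))

eps = 1e-6
num_grad = np.zeros_like(Z)
for idx in np.ndindex(*Z.shape):
    Zp, Zm = Z.copy(), Z.copy()
    Zp[idx] += eps
    Zm[idx] -= eps
    num_grad[idx] = (tsne_loss(P, Zp) - tsne_loss(P, Zm)) / (2 * eps)
max_err = np.max(np.abs(num_grad - tsne_grad(P, Z)))
```

The extra factor `kern` in `tsne_grad` is the heavy-tailed term: it shrinks the force between points that are already far apart in the embedding.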


20.4.10.4 Choosing the length scale 


An important parameter in t-SNE is the local bandwidth $\sigma_i^2$. This is usually chosen so that $P_i$ has a
perplexity chosen by the user.⁷ This can be interpreted as a smooth measure of the effective number
of neighbors. 


7. The perplexity is defined to be $2^{H(P_i)}$, where $H(P_i) = -\sum_j p_{j|i} \log_2 p_{j|i}$ is the entropy; see Section 6.1.5 for details.
A big radius around each point (large value of $\sigma_i$) will result in a high entropy, and thus high perplexity.
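
Since the perplexity increases monotonically with the bandwidth, a matching $\sigma_i$ can be found with a one-dimensional search. The following is a minimal sketch using bisection on a log scale; the function names, the search interval, and the toy squared distances are assumptions, and real implementations differ in detail.

```python
import math

def perplexity(sq_dists, sigma):
    """Perplexity 2^{H(P_i)} of p_{j|i}, given squared distances from x_i to the other points."""
    logits = [-d / (2.0 * sigma ** 2) for d in sq_dists]
    m = max(logits)                            # subtract the max for numerical stability
    weights = [math.exp(l - m) for l in logits]
    total = sum(weights)
    probs = [w / total for w in weights]
    entropy = -sum(p * math.log2(p) for p in probs if p > 0)
    return 2.0 ** entropy

def find_sigma(sq_dists, target, lo=1e-3, hi=1e3, n_iter=100):
    """Bisection on log(sigma); assumes the target perplexity is reachable in [lo, hi]."""
    for _ in range(n_iter):
        mid = math.sqrt(lo * hi)               # geometric midpoint
        if perplexity(sq_dists, mid) < target:
            lo = mid
        else:
            hi = mid
    return math.sqrt(lo * hi)

# Squared distances from one point to 10 others (made-up values).
sq_dists = [0.1, 0.2, 0.5, 1.0, 2.0, 4.0, 9.0, 16.0, 25.0, 36.0]
sigma = find_sigma(sq_dists, target=4.0)
```

With 10 candidate neighbors the perplexity ranges from 1 (tiny $\sigma_i$, all mass on the closest point) up to 10 (huge $\sigma_i$, uniform distribution), so any target strictly between these values can be matched.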

Figure 20.41: t-SNE applied to (a) Swiss roll. Generated by manifold_swiss_sklearn.ipynb. (b) UCI digits.
Generated by manifold_digits_sklearn.ipynb.

Figure 20.42: Illustration of the effect of changing the perplexity parameter when t-SNE is applied to some
2d data. From [WVJ16]. See http://distill.pub/2016/misread-tsne for an animated version of these
figures. Used with kind permission of Martin Wattenberg.


Unfortunately, the results of t-SNE can be quite sensitive to the perplexity parameter, so it is 
wise to run the algorithm with many different values. This is illustrated in Figure 20.42. The input 
data is 2d, so there is no distortion generated by mapping to a 2d latent space. If the perplexity
is too small, the method tends to find structure within each cluster which is not truly present. At 
perplexity 30 (the default for scikit-learn), the clusters seem equi-distant in embedding space, even 
though some are closer than others in the data space. Many other caveats in interpreting t-SNE 
plots can be found in [WVJ16]. 


20.4.10.5 Computational issues 


The naive implementation of t-SNE takes $O(N^2)$ time, as can be seen from the gradient term in
Equation (20.138). A faster version can be created by leveraging an analogy to N-body simulation in 
physics. In particular, the gradient requires computing the force of N points on each of N points. 
However, points that are far away can be grouped into clusters (computationally speaking), and 
their effective force can be approximated by a few representative points per cluster. We can then 
approximate the forces using the Barnes-Hut algorithm [BH86], which takes O(N log N) time, as 
proposed in [Maa14]. Unfortunately, this only works well for low dimensional embeddings, such as
$L = 2$.
20.4.10.6 UMAP


Various extensions of t-SNE have been proposed that try to improve its speed, the quality of the
embedding space, or the ability to embed into more than 2 dimensions.

One popular recent extension, called UMAP (which stands for “Uniform Manifold Approximation
and Projection”), was proposed in [MHM18]. At a high level, this is similar to t-SNE, but it tends to
preserve global structure better, and it is much faster. This makes it easier to try multiple values of
the hyperparameters. For an interactive tutorial on UMAP, and a comparison to t-SNE, see [CP19].


20.5 Word embeddings 


Words are categorical random variables, so their corresponding one-hot vector representations are 
sparse. The problem with this binary representation is that semantically similar words may have 
very different vector representations. For example, the pair of related words “man” and “woman” will 
be Hamming distance 1 apart, as will the pair of unrelated words “man” and “banana”. 

The standard way to solve this problem is to use word embeddings, in which we map each sparse 
one-hot vector, $s_{n,t} \in \{0, 1\}^M$, representing the $t$'th word in document $n$, to a lower-dimensional dense
vector, $z_{n,t} \in \mathbb{R}^D$, such that semantically similar words are placed close by. This can significantly
help with data sparsity. There are many ways to learn such embeddings, as we discuss below. 

Before discussing methods, we have to define what we mean by “semantically similar” words. We 
will assume that two words are semantically similar if they occur in similar contexts. This is known 
as the distributional hypothesis [Har54], which is often summarized by the phrase (originally from
[Fir57]) “a word is characterized by the company it keeps”. Thus the methods we discuss will all
learn a mapping from a word’s context to an embedding vector for that word. 


20.5.1 Latent semantic analysis / indexing 


In this section, we discuss a simple way to learn word embeddings based on singular value decompo- 
sition (Section 7.5) of a term-frequency count matrix. 


20.5.1.1 Latent semantic indexing (LSI) 


Let $C_{ij}$ be the number of times “term” $i$ occurs in “context” $j$. The definition of what we mean by
“term” is application-specific. In English, we often take it to be the set of unique tokens that are 
separated by punctuation or whitespace; for simplicity, we will call these “words”. However, we may 
preprocess the text data to remove very frequent or infrequent words, or perform other kinds of 
preprocessing, as we discuss in Section 1.5.4.1.

The definition of what we mean by “context” is also application-specific. In this section, we 
count how many times word $i$ occurs in each document $j \in \{1, \ldots, N\}$ from a set or corpus of
documents; the resulting matrix $C$ is called a term-document frequency matrix, as in Figure 1.15.
(Sometimes we apply the TF-IDF transformation to the counts, as discussed in Section 1.5.4.2.) 

Let $C \in \mathbb{R}^{M \times N}$ be the count matrix, and let $\hat{C}$ be the rank $K$ approximation that minimizes the
following loss:

$$
\mathcal{L}(\hat{C}) = \|C - \hat{C}\|_F^2 = \sum_{ij} (C_{ij} - \hat{C}_{ij})^2 \tag{20.139}
$$

Figure 20.43: Illustration of the cosine similarity between a query vector $q$ and two document vectors $d_1$
and $d_2$. Since angle $\alpha$ is less than angle $\theta$, we see that the query is more similar to document 1. From
https://en.wikipedia.org/wiki/Vector_space_model. Used with kind permission of Wikipedia author
Riclas.


One can show that the minimizer of this is given by the rank K truncated SVD approximation, 
$\hat{C} = U S V^T$. This means we can represent each $c_{ij}$ as a bilinear product:

$$
c_{ij} \approx \sum_{k=1}^{K} u_{ik} s_k v_{jk} \tag{20.140}
$$


We define $u_i$ to be the embedding for word $i$, and $s \odot v_j$ to be the embedding for context $j$.

We can use these embeddings for document retrieval. The idea is to compute an embedding for
the query words using $u_i$, and to compare this to the embedding of all the documents or contexts $v_j$.
This is known as latent semantic indexing or LSI [Dee+90]. 

In more detail, suppose the query is a bag of words $w_1, \ldots, w_B$; we represent this by the vector
$q = \frac{1}{B} \sum_{b=1}^{B} u_{w_b}$, where $u_{w_b}$ is the embedding for word $w_b$. Let document $j$ be represented by $v_j$.
We then rank documents by the cosine similarity between the query vector and document, defined
by

$$
\mathrm{sim}(q, d) = \frac{q^T d}{\|q\| \, \|d\|} \tag{20.141}
$$

where $\|q\| = \sqrt{\sum_i q_i^2}$ is the $\ell_2$-norm of $q$. This measures the cosine of the angle between the two vectors, as
shown in Figure 20.43. Note that if the vectors are unit norm, cosine similarity is the same as inner
product; it is also equal to the squared Euclidean distance, up to a change of sign, a rescaling, and an
irrelevant additive constant:

$$
\|q - d\|^2 = (q - d)^T (q - d) = q^T q + d^T d - 2 q^T d = 2(1 - \mathrm{sim}(q, d)) \tag{20.142}
$$
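
The retrieval recipe above can be sketched with NumPy's SVD. The tiny term-document matrix, the vocabulary, and the query below are made up for illustration, and the variable names are assumptions; following the text, the query is the average of the word embeddings $u_i$ and each document is represented by $v_j$.

```python
import numpy as np

# Toy term-document counts (rows: terms, columns: documents); made-up data.
terms = ["cat", "dog", "pet", "stock", "market"]
C = np.array([[2., 1., 0., 0.],
              [1., 2., 0., 0.],
              [1., 1., 0., 1.],
              [0., 0., 3., 1.],
              [0., 0., 2., 2.]])

K = 2
U, s, Vt = np.linalg.svd(C, full_matrices=False)
U_k, s_k, V_k = U[:, :K], s[:K], Vt[:K].T
C_hat = U_k @ np.diag(s_k) @ V_k.T             # rank-K truncated SVD approximation

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Query embedding: average of the embeddings of the query words.
query = ["cat", "pet"]
q = np.mean([U_k[terms.index(w)] for w in query], axis=0)

# Score every document embedding v_j against the query and sort, best first.
scores = [cosine(q, V_k[j]) for j in range(C.shape[1])]
ranking = np.argsort(scores)[::-1]
```

The squared Frobenius error of `C_hat` equals the sum of the squared discarded singular values, which is one way to see how much of the count matrix a given K retains.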


20.5.1.2 Latent semantic analysis (LSA) 


Now suppose we define context more generally to be some local neighborhood of words $j \in \{1, \ldots, M^h\}$,
where $h$ is the window size. Thus $C_{ij}$ is how many times word $i$ occurs in a neighborhood of type $j$.
We can compute the SVD of this matrix as before, to get $c_{ij} \approx \sum_{k=1}^{K} u_{ik} s_k v_{jk}$. We define $u_i$ to be
the embedding for word $i$, and $s \odot v_j$ to be the embedding for context $j$. This is known as latent
semantic analysis or LSA [Dee+90].

For example, suppose we compute $C$ on the British National Corpus.⁸ For each word, let us
retrieve the $K$ nearest neighbors in embedding space ranked by cosine similarity (i.e., normalized
inner product). If the query word is “dog”, and we use $h = 2$ or $h = 30$, the nearest neighbors are as
follows:



h=2: cat, horse, fox, pet, rabbit, pig, animal, mongrel, sheep, pigeon 
h=30: kennel, puppy, pet, bitch, terrier, rottweiler, canine, cat, to bark 


The 2-word context window is more sensitive to syntax, while the 30-word window is more sensitive 
to semantics. The “optimal” value of context size h depends on the application. 
20.5.1.3 PMI 


In practice LSA (and other similar methods) gives much better results if we replace the raw counts
$C_{ij}$ with pointwise mutual information (PMI) [CH90], defined as

$$
\mathrm{PMI}(i, j) = \log \frac{p(i, j)}{p(i) p(j)} \tag{20.143}
$$

If word $i$ is strongly associated with context $j$, we will have $\mathrm{PMI}(i, j) > 0$. If the PMI is negative, it
means $i$ and $j$ co-occur less often than if they were independent; however, such negative correlations
can be unreliable, so it is common to use the positive PMI: $\mathrm{PPMI}(i, j) = \max(\mathrm{PMI}(i, j), 0)$. In
[BL07b], they show that SVD applied to the PPMI matrix results in word embeddings that perform
well on many tasks related to word meaning. See Section 20.5.5 for a theoretical model that explains
this empirical performance.
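
As a small illustration, the PPMI transformation of a count matrix can be written in a few lines of NumPy. The function name `ppmi` and the toy counts are assumptions for this sketch; the probabilities are estimated by normalizing the counts, and entries with zero counts (whose PMI is $-\infty$) are mapped to 0.

```python
import numpy as np

def ppmi(C):
    """Positive PMI from a (terms x contexts) co-occurrence count matrix."""
    total = C.sum()
    p_ij = C / total                           # empirical joint p(i, j)
    p_i = p_ij.sum(axis=1, keepdims=True)      # marginal over contexts
    p_j = p_ij.sum(axis=0, keepdims=True)      # marginal over terms
    with np.errstate(divide="ignore", invalid="ignore"):
        pmi = np.log(p_ij / (p_i * p_j))
    pmi[~np.isfinite(pmi)] = 0.0               # zero counts: clip to 0, as PPMI would
    return np.maximum(pmi, 0.0)

# Toy co-occurrence counts (made up).
C = np.array([[10., 0., 2.],
              [ 0., 8., 2.],
              [ 2., 2., 4.]])
M = ppmi(C)
```

The resulting matrix `M` can then be passed to a truncated SVD in place of the raw counts.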


20.5.2 Word2vec 


In this section, we discuss the popular word2vec models from [Mik+13a; Mik+13b], which are
“shallow” neural nets for predicting a word given its context. In Section 20.5.5, we will discuss the 
connections with SVD of the PMI matrix. 

There are two versions of the word2vec model. The first is called CBOW, which stands for 
“continuous bag of words”. The second is called skipgram. We discuss both of these below. 


20.5.2.1 Word2vec CBOW model 


In the continuous bag of words (CBOW) model (see Figure 20.44(a)), the log likelihood of a sequence 
of words is computed using the following model: 


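$$
\begin{aligned}
\log p(w) &= \sum_{t=1}^{T} \log p(w_t | w_{t-m:t+m}) = \sum_{t=1}^{T} \log \frac{\exp(v_{w_t}^T \bar{v}_t)}{\sum_{w'} \exp(v_{w'}^T \bar{v}_t)} && (20.144) \\
&= \sum_{t=1}^{T} \Big[ v_{w_t}^T \bar{v}_t - \log \sum_{i \in \mathcal{V}} \exp(v_i^T \bar{v}_t) \Big] && (20.145)
\end{aligned}
$$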


8. This example is taken from [Eis19, p312]. 
Figure 20.44: Illustration of word2vec model with window size of 2. (a) CBOW version. (b) Skip-gram
version.


where $v_{w_t}$ is the vector for the word at location $t$, $\mathcal{V}$ is the set of all words, $m$ is the context size,
and

$$
\bar{v}_t = \frac{1}{2m} \sum_{h=1}^{m} (v_{w_{t+h}} + v_{w_{t-h}}) \tag{20.146}
$$

is the average of the word vectors in the window around word $w_t$. Thus we try to predict each word
given its context. The model is called CBOW because it uses a bag of words assumption for the 
context, and represents each word by a continuous embedding. 


20.5.2.2 Word2vec Skip-gram model 


In CBOW, each word is predicted from its context. A variant of this is to predict the context 
(surrounding words) given each word. This yields the following objective: 


$$
\begin{aligned}
-\log p(w) &= -\sum_{t=1}^{T} \Big[ \sum_{j=1}^{m} \log p(w_{t-j} | w_t) + \log p(w_{t+j} | w_t) \Big] && (20.147) \\
&= -\sum_{t=1}^{T} \sum_{-m \leq j \leq m,\, j \neq 0} \log p(w_{t+j} | w_t) && (20.148)
\end{aligned}
$$

where $m$ is the context window length. We define the log probability of some other context word $w_o$
given the central word $w_c$ to be

$$
\log p(w_o | w_c) = u_o^T v_c - \log \Big( \sum_{i \in \mathcal{V}} \exp(u_i^T v_c) \Big) \tag{20.149}
$$

where $\mathcal{V}$ is the vocabulary. Here $u_i$ is the embedding of a word if used as context, and $v_i$ is the
embedding of a word if used as a central (target) word to be predicted. This model is known as the 
skipgram model. See Figure 20.44(b) for an illustration. 


20.5.2.3 Negative sampling 


Computing the conditional probability of each word using Equation (20.149) is expensive, due to the 
need to normalize over all possible words in the vocabulary. This makes it slow to compute the log 
likelihood and its gradient, for both the CBOW and skip-gram models. 

In [Mik+13b], they propose a fast approximation, called skip-gram with negative sampling 
(SGNS). The basic idea is to create a set of $K + 1$ context words for each central word $w_t$, and
to label the one that actually occurs as positive, and the rest as negative. The negative words are
called noise words, and can be sampled from a reweighted unigram distribution, $p(w) \propto \mathrm{freq}(w)^{3/4}$,
which has the effect of redistributing probability mass from common to rare words. The conditional
probability is now approximated by

$$
p(w_{t+j} | w_t) = p(D = 1 | w_t, w_{t+j}) \prod_{k=1}^{K} p(D = 0 | w_t, w_k) \tag{20.150}
$$

where $w_k \sim p(w)$ are noise words, and $D = 1$ is the event that the word pair actually occurs in the
data, and $D = 0$ is the event that the word pair does not occur. The binary probabilities are given by

$$
\begin{aligned}
p(D = 1 | w_t, w_{t+j}) &= \sigma(u_{w_{t+j}}^T v_{w_t}) && (20.151) \\
p(D = 0 | w_t, w_k) &= 1 - \sigma(u_{w_k}^T v_{w_t}) && (20.152)
\end{aligned}
$$

To train this model, we just need to compute the contexts for each central word, and a set of
negative noise words. We associate a label of 1 with the context words, and a label of 0 with the
noise words. We can then compute the log probability of the data, and optimize the embedding
vectors $u_i$ and $v_i$ for each word using SGD. See skipgram_jax.ipynb for some sample code.
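
To make the pieces concrete, the sketch below computes the reweighted noise distribution and the log of the approximation in Equation (20.150) for a single (center, context) pair. It is not the code from skipgram_jax.ipynb; the function names, the random word frequencies, and the embedding sizes are all assumptions chosen for illustration.

```python
import numpy as np

def noise_distribution(freq, power=0.75):
    """Reweighted unigram distribution p(w) proportional to freq(w)^{3/4}."""
    p = np.asarray(freq, dtype=float) ** power
    return p / p.sum()

def log_sigmoid(x):
    return -np.logaddexp(0.0, -x)              # numerically stable log(sigma(x))

def sgns_log_prob(U, V, center, context, negatives):
    """Log probability of one observed pair plus K sampled noise words."""
    pos = log_sigmoid(U[context] @ V[center])
    # log(1 - sigma(x)) = log(sigma(-x)), summed over the noise words.
    neg = log_sigmoid(-(U[negatives] @ V[center])).sum()
    return pos + neg

rng = np.random.default_rng(0)
vocab_size, dim, K = 50, 8, 5
freq = rng.integers(1, 1000, size=vocab_size)  # made-up word counts
p_noise = noise_distribution(freq)

U = 0.1 * rng.normal(size=(vocab_size, dim))   # context embeddings u_i
V = 0.1 * rng.normal(size=(vocab_size, dim))   # center-word embeddings v_i
negatives = rng.choice(vocab_size, size=K, p=p_noise)
ll = sgns_log_prob(U, V, center=3, context=7, negatives=negatives)
```

Each evaluation costs $O(K)$ dot products rather than a sum over the whole vocabulary, which is the source of the speedup. In training, the negative of this quantity would be summed over all pairs and minimized with SGD.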


20.5.3 GloVe


A popular alternative to skipgram is the GloVe model of [PSM14a]. (GloVe stands for “global
vectors for word representation”.) This method uses a simpler objective, which is much faster to 
optimize. 

To explain the method, recall that in the skipgram model, the predicted conditional probability of
word $j$ occurring in the context window of central word $i$ is

$$
q_{ij} = \frac{\exp(u_j^T v_i)}{\sum_{k \in \mathcal{V}} \exp(u_k^T v_i)} \tag{20.153}
$$

Let $x_{ij}$ be the number of times word $j$ occurs in any context window of $i$. (Note that if word $i$ occurs
in the window of $j$, then $j$ will occur in the window of $i$, so we have $x_{ij} = x_{ji}$.) Then we can rewrite
Equation (20.148) as follows:

$$
\mathcal{L} = -\sum_{i \in \mathcal{V}} \sum_{j \in \mathcal{V}} x_{ij} \log q_{ij} \tag{20.154}
$$

Figure 20.45: Visualization of arithmetic operations in word2vec embedding space. From
https://www.tensorflow.org/tutorials/representation/word2vec.


If we define $p_{ij} = x_{ij} / x_i$ to be the empirical probability of word $j$ occurring in the context window of
central word $i$, where $x_i = \sum_k x_{ik}$, we can rewrite the skipgram loss as a cross entropy loss:

$$
\mathcal{L} = -\sum_{i \in \mathcal{V}} x_i \sum_{j \in \mathcal{V}} p_{ij} \log q_{ij} \tag{20.155}
$$

The problem with this objective is that computing $q_{ij}$ is expensive, due to the need to normalize over
all words. In GloVe, we work with unnormalized probabilities, $p'_{ij} = x_{ij}$ and $q'_{ij} = \exp(u_j^T v_i + b_i + c_j)$,
where $b_i$ and $c_j$ are bias terms to capture marginal probabilities. In addition, we minimize the
squared loss, $(\log p'_{ij} - \log q'_{ij})^2$, which is more robust to errors in estimating small probabilities than
log loss. Finally, we upweight rare words for which $x_{ij} < c$, where $c = 100$, by weighting the squared
errors by $h(x_{ij})$, where $h(x) = (x/c)^{0.75}$ if $x < c$, and $h(x) = 1$ otherwise. This gives the final GloVe
objective:

$$
\mathcal{L} = \sum_{i \in \mathcal{V}} \sum_{j \in \mathcal{V}} h(x_{ij}) (u_j^T v_i + b_i + c_j - \log x_{ij})^2 \tag{20.156}
$$


We can precompute $x_{ij}$ offline, and then optimize the above objective using SGD. After training, we
define the embedding of word $i$ to be the average of $v_i$ and $u_i$.

Empirically GloVe gives similar results to skipgram, but it is faster to train. See Section 20.5.5 for a
theoretical model that explains why these methods work. 
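
The GloVe objective is simple enough to write down directly. The sketch below evaluates the weighting function and the weighted squared error for given parameters. The function names and toy sizes are assumptions, and the sum is restricted to pairs with $x_{ij} > 0$, since $\log x_{ij}$ is undefined otherwise and such pairs receive zero weight in the original GloVe formulation.

```python
import numpy as np

def glove_weight(x, c=100.0, alpha=0.75):
    """h(x) = (x / c)^alpha for x < c, and 1 otherwise."""
    return np.where(x < c, (x / c) ** alpha, 1.0)

def glove_loss(X, U, V, b, c_bias):
    """Weighted squared error over the observed (nonzero) co-occurrence counts."""
    i, j = np.nonzero(X)
    pred = np.sum(U[j] * V[i], axis=1) + b[i] + c_bias[j]   # u_j^T v_i + b_i + c_j
    err = pred - np.log(X[i, j])
    return np.sum(glove_weight(X[i, j]) * err ** 2)

rng = np.random.default_rng(0)
vocab_size, dim = 20, 5
counts = rng.poisson(3.0, size=(vocab_size, vocab_size)).astype(float)
X = counts + counts.T                          # symmetric toy counts, x_ij = x_ji
U = 0.1 * rng.normal(size=(vocab_size, dim))
V = 0.1 * rng.normal(size=(vocab_size, dim))
b = np.zeros(vocab_size)
c_bias = np.zeros(vocab_size)
loss = glove_loss(X, U, V, b, c_bias)
```

Because only the nonzero entries of the count matrix are visited, the cost of one pass scales with the number of observed word pairs rather than with the square of the vocabulary size.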


20.5.4 Word analogies 


One of the most remarkable properties of word embeddings produced by word2vec, GloVe, and other 
similar methods is that the learned vector space seems to capture relational semantics in terms 
of simple vector addition. For example, consider the word analogy problem “man is to woman 
as king is to queen”, often written as man:woman::king:queen. Suppose we are given the words 
a=man, b=woman, c=king; how do we find d=queen? Let $\delta = v_b - v_a$ be the vector representing the
concept of “converting the gender from male to female”. Intuitively we can find word $d$ by computing
$v_d = v_c + \delta$, and then finding the closest word in the vocabulary to $v_d$. See Figure 20.45 for an
illustration of this process, and word_analogies_jax.ipynb for some code.
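
The vector arithmetic can be illustrated with a toy example. The 2d embeddings below are hand-built and hypothetical (one axis loosely encodes "royalty", the other "gender"); they are not learned vectors, and this is not the code in word_analogies_jax.ipynb. Closeness is measured here by cosine similarity, and the three query words are excluded from the candidates, both of which are assumptions of this sketch.

```python
import numpy as np

# Hand-built 2d embeddings (hypothetical): axis 0 ~ "royalty", axis 1 ~ "gender".
emb = {
    "man":    np.array([0.0, -1.0]),
    "woman":  np.array([0.0,  1.0]),
    "king":   np.array([1.0, -1.0]),
    "queen":  np.array([1.0,  1.0]),
    "banana": np.array([-1.0, 0.2]),
}

def analogy(a, b, c, emb):
    """Return the word d whose vector is closest (in cosine) to v_c + (v_b - v_a)."""
    target = emb[c] + (emb[b] - emb[a])
    best, best_sim = None, -np.inf
    for word, vec in emb.items():
        if word in (a, b, c):                  # exclude the query words themselves
            continue
        sim = target @ vec / (np.linalg.norm(target) * np.linalg.norm(vec))
        if sim > best_sim:
            best, best_sim = word, sim
    return best

answer = analogy("man", "woman", "king", emb)
```

With learned embeddings the match is rarely exact, so the nearest-neighbor search over the vocabulary is what makes the method work in practice.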

In [PSM14a], they conjecture that $a : b :: c : d$ holds iff for every word $w$ in the vocabulary, we
have

$$
\frac{p(w|a)}{p(w|b)} = \frac{p(w|c)}{p(w|d)} \tag{20.157}
$$

In [Aro+16], they show that this follows from the RAND-WALK modeling assumptions in Section 20.5.5.
See also [AH19; EDH19] for other explanations of why word analogies work, based on
different modeling assumptions.


20.5.5 RAND-WALK model of word embeddings 


Word embeddings significantly improve the performance of various kinds of NLP models compared 
to using one-hot encodings for words. It is natural to wonder why the above word embeddings work 
so well. In this section, we give a simple generative model for text documents that explains this 
phenomenon, based on [Aro+16]. 

Consider a sequence of words $w_1, \ldots, w_T$. We assume each word is generated by a latent context
or discourse vector $z_t \in \mathbb{R}^D$ using the following log bilinear language model, similar to [MH07]:

$$
p(w_t = w | z_t) = \frac{\exp(z_t^T v_w)}{\sum_{w'} \exp(z_t^T v_{w'})} = \frac{\exp(z_t^T v_w)}{Z(z_t)} \tag{20.158}
$$

where $v_w \in \mathbb{R}^D$ is the embedding for word $w$, and $Z(z_t)$ is the partition function. We assume $D < M$,
the number of words in the vocabulary.

Let us further assume the prior for the word embeddings $v_w$ is an isotropic Gaussian, and that the
latent topic $z_t$ undergoes a slow Gaussian random walk. (This is therefore called the RAND-WALK
model.) Under this model, one can show that $Z(z_t)$ is approximately equal to a fixed constant,
$Z$, independent of the context. This is known as the self-normalization property of log-linear
models [AK15]. Furthermore, one can show that the pointwise mutual information of predictions
from the model is given by

$$
\mathrm{PMI}(w, w') = \log \frac{p(w, w')}{p(w) p(w')} \approx \frac{v_w^T v_{w'}}{D} \tag{20.159}
$$

We can therefore fit the RAND-WALK model by matching the model’s predicted values for PMI
with the empirical values, i.e., we minimize

$$
\mathcal{L} = \sum_{w, w'} X_{w,w'} \left( \mathrm{PMI}(w, w') - v_w^T v_{w'} \right)^2 \tag{20.160}
$$

where $X_{w,w'}$ is the number of times $w$ and $w'$ occur next to each other. This objective can be seen as
a frequency-weighted version of the SVD loss in Equation (20.139). (See [LG14] for more connections 
between word embeddings and SVD.) 

Furthermore, some additional approximations can be used to show that the NLL for the RAND-WALK
model is equivalent to the CBOW and SGNS word2vec objectives. We can also derive the
objective for GloVe from this approach.

20.5.6 Contextual word embeddings


Consider the sentences “I was eating an apple” and “I bought a new phone from Apple”. The 
meaning of the word “apple” differs between the two cases, but a fixed word embedding, of the type
discussed in Section 20.5, would not be able to capture this. In Section 15.7, we discuss contextual 
word embeddings, where the embedding of a word is a function of all the words in its context 
(usually a sentence). This can give much improved results, and is currently the standard approach 
to representing natural language data, as a pre-processing step before doing transfer learning (see 
Section 19.2). 


20.6 Exercises 


Exercise 20.1 [EM for FA] 

Derive the EM updates for the factor analysis model. For simplicity, you can optionally assume $\mu = 0$ is
fixed. 

Exercise 20.2 [EM for mixFA *]


Derive the EM updates for a mixture of factor analysers. 


Exercise 20.3 [Deriving the second principal component]

a. Let

$$
J(v_2, z_2) = \frac{1}{n} \sum_{i=1}^{n} (x_i - z_{i1} v_1 - z_{i2} v_2)^T (x_i - z_{i1} v_1 - z_{i2} v_2) \tag{20.161}
$$

Show that $\frac{\partial J}{\partial z_2} = 0$ yields $z_{i2} = v_2^T x_i$.

b. Show that the value of $v_2$ that minimizes

$$
J(v_2) = -v_2^T C v_2 + \lambda_2 (v_2^T v_2 - 1) + \lambda_{12} (v_2^T v_1 - 0) \tag{20.162}
$$

is given by the eigenvector of $C$ with the second largest eigenvalue. Hint: recall that $C v_1 = \lambda_1 v_1$ and


Exercise 20.4 [Deriving the residual error for PCA *] 


a. Prove that

$$
\Big\| x_i - \sum_{j=1}^{K} z_{ij} v_j \Big\|^2 = x_i^T x_i - \sum_{j=1}^{K} v_j^T x_i x_i^T v_j \tag{20.163}
$$

Hint: first consider the case $K = 2$. Use the fact that $v_j^T v_j = 1$ and $v_j^T v_k = 0$ for $k \neq j$. Also, recall
$z_{ij} = x_i^T v_j$.

b. Now show that

$$
J_K \triangleq \frac{1}{n} \sum_{i=1}^{n} \Big( x_i^T x_i - \sum_{j=1}^{K} v_j^T x_i x_i^T v_j \Big) = \frac{1}{n} \sum_{i=1}^{n} x_i^T x_i - \sum_{j=1}^{K} \lambda_j \tag{20.164}
$$

Hint: recall $v_j^T C v_j = \lambda_j v_j^T v_j = \lambda_j$.

c. If $K = d$ there is no truncation, so $J_d = 0$. Use this to show that the error from only using $K < d$ terms
is given by

$$
J_K = \sum_{j=K+1}^{d} \lambda_j \tag{20.165}
$$

Hint: partition the sum $\sum_{j=1}^{d} \lambda_j$ into $\sum_{j=1}^{K} \lambda_j$ and $\sum_{j=K+1}^{d} \lambda_j$.


Exercise 20.5 [PCA via successive deflation] 


Let v1, V2,..., Uk be the first k eigenvectors with largest eigenvalues of C = XTX, where X is the centered 
N x D design matrix; these are known as the principal basis vectors. These satisfy 
r, _f 0 ifj#k 
vj Vk = { 1 GLE (20.166) 


We will construct a method for finding the v; sequentially. 


As we showed in class, vı is the first principal eigenvector of C, and satisfies Cvı = A1v1. Now define č; as 
the orthogonal projection of æ; onto the space orthogonal to v1: 


ži = Piv, x; = (I — viv} )a; (20.167) 


Define X = [£1;...; Zn] as the deflated matrix of rank d — 1, which is obtained by removing from the d 
dimensional data the component that lies in the direction of the first principal direction: 


X = (I — wv, )"X = (I — viv] )X (20.168) 


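a. Using the facts that $X^T X v_1 = n \lambda_1 v_1$ (and hence $v_1^T X^T X = n \lambda_1 v_1^T$) and $v_1^T v_1 = 1$, show that the
covariance of the deflated matrix is given by

$$ \tilde{C} \triangleq \frac{1}{n} \tilde{X}^T \tilde{X} = \frac{1}{n} X^T X - \lambda_1 v_1 v_1^T \tag{20.169} $$

b. Let $u$ be the principal eigenvector of $\tilde{C}$. Explain why $u = v_2$. (You may assume $u$ is unit norm.)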


c. Suppose we have a simple method for finding the leading eigenvector and eigenvalue of a pd matrix,
denoted by $[\lambda, u] = f(C)$. Write some pseudo code for finding the first $K$ principal basis vectors of $X$
that only uses the special $f$ function and simple vector arithmetic, i.e., your code should not use SVD or
the eig function. Hint: this should be a simple iterative routine that takes 2-3 lines to write. The input
is $C$, $K$ and the function $f$, the output should be $v_j$ and $\lambda_j$ for $j = 1 : K$.


Exercise 20.6 [PPCA variance terms] 


Recall that in the PPCA model, $C = W W^T + \sigma^2 I$. We will show that this model correctly captures the
variance of the data along the principal axes, and approximates the variance in all the remaining directions
with a single average value $\sigma^2$.

Consider the variance of the predictive distribution $p(x)$ along some direction specified by the unit vector $v$,
where $v^T v = 1$, which is given by $v^T C v$.

a. First suppose $v$ is orthogonal to the principal subspace, and hence $v^T U = 0$. Show that $v^T C v = \sigma^2$.

b. Now suppose $v$ is parallel to the principal subspace, and hence $v = u_i$ for some eigenvector $u_i$. Show
that $v^T C v = (\lambda_i - \sigma^2) + \sigma^2 = \lambda_i$.


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




Exercise 20.7 [Posterior inference in PPCA *] 


Derive $p(z_n | x_n)$ for the PPCA model.


Exercise 20.8 [Imputation in a FA model *] 


Derive an expression for $p(x_h | x_v, \theta)$ for a FA model, where $x = (x_h, x_v)$ is a partition of the data vector.


Exercise 20.9 [Efficiently evaluating the PPCA density] 


Derive an expression for $p(x | \hat{W}, \hat{\sigma}^2)$ for the PPCA model based on plugging in the MLEs and using the
matrix inversion lemma.
21 Clustering


21.1 Introduction 


Clustering is a very common form of unsupervised learning. There are two main kinds of methods. 
In the first approach, the input is a set of data samples $D = \{x_n : n = 1 : N\}$, where $x_n \in \mathcal{X}$, where
typically $\mathcal{X} = \mathbb{R}^D$. In the second approach, the input is an $N \times N$ pairwise dissimilarity metric
$D_{ij} \geq 0$. In both cases, the goal is to assign similar data points to the same cluster.

As is often the case with unsupervised learning, it is hard to evaluate the quality of a clustering 
algorithm. If we have labeled data for some of the data, we can use the similarity (or equality) 
between the labels of two data points as a metric for determining if the two inputs “should” be 
assigned to the same cluster or not. If we don’t have labels, but the method is based on a generative 
model of the data, we can use log likelihood as a metric. We will see examples of both approaches 
below. 


21.1.1 Evaluating the output of clustering methods 


The validation of clustering structures is the most difficult and frustrating part of cluster 
analysis. Without a strong effort in this direction, cluster analysis will remain a black art 
accessible only to those true believers who have experience and great courage. — Jain and 
Dubes [JD88]


Clustering is an unsupervised learning technique, so it is hard to evaluate the quality of the output 
of any given method [Kle02; LWG12]. If we use probabilistic models, we can always evaluate the 
likelihood of the data, but this has two drawbacks: first, it does not directly assess any clustering 
that is discovered by the model; and second, it does not apply to non-probabilistic methods. So now 
we discuss some performance measures not based on likelihood. 

Intuitively, the goal of clustering is to assign points that are similar to the same cluster, and to 
ensure that points that are dissimilar are in different clusters. There are several ways of measuring 
these quantities e.g., see [JD88; KR90]. However, these internal criteria may be of limited use. An 
alternative is to rely on some external form of data with which to validate the method. For example, 
if we have labels for each object, then we can assume that objects with the same label are similar. 
We can then use the metrics we discuss below to quantify the quality of the clusters. (If we do not 
have labels, but we have a reference clustering, we can derive labels from that clustering.) 
Figure 21.1: Three clusters with labeled objects inside.


21.1.1.1 Purity 


Let $N_{ij}$ be the number of objects in cluster $i$ that belong to class $j$, and let $N_i = \sum_{j=1}^{C} N_{ij}$ be the
total number of objects in cluster $i$. Define $p_{ij} = N_{ij}/N_i$; this is the empirical distribution over class
labels for cluster $i$. We define the purity of a cluster as $p_i \triangleq \max_j p_{ij}$, and the overall purity of a
clustering as

$$ \text{purity} \triangleq \sum_i \frac{N_i}{N} p_i \tag{21.1} $$

For example, in Figure 21.1, we have that the purity is

$$ \frac{6}{17} \cdot \frac{5}{6} + \frac{6}{17} \cdot \frac{4}{6} + \frac{5}{17} \cdot \frac{3}{5} = \frac{5 + 4 + 3}{17} = 0.71 \tag{21.2} $$

The purity ranges between 0 (bad) and 1 (good). However, we can trivially achieve a purity of 1 by
putting each object into its own cluster, so this measure does not penalize for the number of clusters.
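The purity calculation can be sketched in a few lines of Python. The contingency table below is an assumption: it is one
set of counts consistent with the numbers quoted for Figure 21.1 (cluster sizes 6, 6 and 5, with majority counts 5, 4
and 3), and the function name `purity` is ours.

```python
def purity(table):
    """Overall purity, where table[i][j] = N_ij (objects in cluster i with class j)."""
    n_total = sum(sum(row) for row in table)
    # sum_i (N_i / N) * max_j (N_ij / N_i) simplifies to (1 / N) * sum_i max_j N_ij.
    return sum(max(row) for row in table) / n_total


# Assumed counts (rows: clusters 1-3, columns: classes A, B, C).
figure_21_1 = [[5, 1, 0], [1, 4, 1], [2, 0, 3]]
print(round(purity(figure_21_1), 2))  # 0.71

# One object per cluster: purity is trivially 1, whatever the labels are.
singletons = [[1, 0], [0, 1], [1, 0]]
print(purity(singletons))  # 1.0
```

The second call illustrates why purity alone cannot be used to compare clusterings with different numbers of clusters.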


21.1.1.2 Rand index 


Let $U = \{u_1, \ldots, u_R\}$ and $V = \{v_1, \ldots, v_C\}$ be two different partitions of the $N$ data points. For
example, $U$ might be the estimated clustering and $V$ the reference clustering derived from the class
labels. Now define a 2 × 2 contingency table, containing the following numbers: TP is the number of
pairs that are in the same cluster in both $U$ and $V$ (true positives); TN is the number of pairs that
are in different clusters in both $U$ and $V$ (true negatives); FN is the number of pairs that are in
different clusters in $U$ but the same cluster in $V$ (false negatives); and FP is the number of pairs
that are in the same cluster in $U$ but different clusters in $V$ (false positives). A common summary
statistic is the Rand index:

$$ R \triangleq \frac{TP + TN}{TP + FP + FN + TN} \tag{21.3} $$

This can be interpreted as the fraction of clustering decisions that are correct. Clearly $0 \leq R \leq 1$.

For example, consider Figure 21.1. The three clusters contain 6, 6 and 5 points, so the number of
“positives” (i.e., pairs of objects put in the same cluster, regardless of label) is

$$ TP + FP = \binom{6}{2} + \binom{6}{2} + \binom{5}{2} = 40 \tag{21.4} $$

Of these, the number of true positives is given by

$$ TP = \binom{5}{2} + \binom{4}{2} + \binom{3}{2} + \binom{2}{2} = 20 \tag{21.5} $$

where the last two terms come from cluster 3: there are $\binom{3}{2}$ pairs labeled C and $\binom{2}{2}$ pairs labeled
A. So FP = 40 − 20 = 20. Similarly, one can show FN = 24 and TN = 72. So the Rand index is
(20 + 72)/(20 + 20 + 24 + 72) = 0.68.
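The pair counts in this example can be checked with a short Python sketch. As before, the contingency table is an
assumption (counts chosen to be consistent with the figures quoted in the text), and `pair_counts` and `rand_index`
are illustrative names rather than functions from any library.

```python
from math import comb


def pair_counts(table):
    """Return (TP, FP, FN, TN), where table[i][j] counts objects in cluster i with class j."""
    n = sum(sum(row) for row in table)
    same_cluster = sum(comb(sum(row), 2) for row in table)      # TP + FP
    same_class = sum(comb(sum(col), 2) for col in zip(*table))  # TP + FN
    tp = sum(comb(n_ij, 2) for row in table for n_ij in row)
    fp = same_cluster - tp
    fn = same_class - tp
    tn = comb(n, 2) - tp - fp - fn
    return tp, fp, fn, tn


def rand_index(table):
    tp, fp, fn, tn = pair_counts(table)
    return (tp + tn) / (tp + fp + fn + tn)


# Assumed counts (rows: clusters 1-3, columns: classes A, B, C).
figure_21_1 = [[5, 1, 0], [1, 4, 1], [2, 0, 3]]
print(pair_counts(figure_21_1))           # (20, 20, 24, 72)
print(round(rand_index(figure_21_1), 2))  # 0.68
```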

The Rand index only achieves its lower bound of 0 if TP = TN = 0, which is a rare event. One
can define an adjusted Rand index [HA85] as follows:

$$ AR \triangleq \frac{\text{index} - \text{expected index}}{\text{max index} - \text{expected index}} \tag{21.6} $$


Here the model of randomness is based on using the generalized hyper-geometric distribution, i.e., 
the two partitions are picked at random subject to having the original number of classes and objects 
in each, and then the expected value of TP + TN is computed. This model can be used to compute 
the statistical significance of the Rand index. 

The Rand index weights false positives and false negatives equally. Various other summary statistics 
for binary decision problems, such as the F-score (Section 5.1.4), can also be used. 


21.1.1.3 Mutual information 


Another way to measure cluster quality is to compute the mutual information between two candidate
partitions $U$ and $V$, as proposed in [VD99]. To do this, let $p_{UV}(i, j) = |u_i \cap v_j| / N$ be the probability that
a randomly chosen object belongs to cluster $u_i$ in $U$ and $v_j$ in $V$. Also, let $p_U(i) = |u_i| / N$ be the
probability that a randomly chosen object belongs to cluster $u_i$ in $U$; define $p_V(j) = |v_j| / N$
similarly. Then we have

$$ I(U, V) = \sum_{i=1}^{R} \sum_{j=1}^{C} p_{UV}(i, j) \log \frac{p_{UV}(i, j)}{p_U(i) \, p_V(j)} \tag{21.7} $$

This lies between 0 and $\min\{H(U), H(V)\}$. Unfortunately, the maximum value can be achieved
by using lots of small clusters, which have low entropy. To compensate for this, we can use the
normalized mutual information,

$$ NMI(U, V) \triangleq \frac{I(U, V)}{(H(U) + H(V)) / 2} \tag{21.8} $$


This lies between 0 and 1. A version of this that is adjusted for chance (under a particular random 
data model) is described in [VEB09]. Another variant, called variation of information, is described 
in [Mei05]. 
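As a rough illustration, the mutual information and its normalized version can be computed directly from a contingency
table of counts $|u_i \cap v_j|$. This is a minimal sketch using natural logarithms; the function names are ours, and
returning 1 when both partitions consist of a single cluster is a convention we assume here.

```python
from math import log


def entropy(probs):
    return -sum(p * log(p) for p in probs if p > 0)


def mutual_information(table):
    """I(U, V), where table[i][j] = |u_i intersect v_j|."""
    n = sum(sum(row) for row in table)
    p_u = [sum(row) / n for row in table]
    p_v = [sum(col) / n for col in zip(*table)]
    mi = 0.0
    for i, row in enumerate(table):
        for j, n_ij in enumerate(row):
            if n_ij > 0:  # terms with p_UV(i, j) = 0 contribute nothing
                p_ij = n_ij / n
                mi += p_ij * log(p_ij / (p_u[i] * p_v[j]))
    return mi


def nmi(table):
    n = sum(sum(row) for row in table)
    h_u = entropy([sum(row) / n for row in table])
    h_v = entropy([sum(col) / n for col in zip(*table)])
    if h_u + h_v == 0:  # both partitions are a single cluster (assumed convention)
        return 1.0
    return mutual_information(table) / ((h_u + h_v) / 2)


identical = [[3, 0], [0, 2]]    # U and V agree exactly
independent = [[1, 1], [1, 1]]  # knowing the cluster says nothing about the class
print(nmi(identical), nmi(independent))
```

For identical partitions the NMI is 1 (up to rounding), and for independent partitions it is 0.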


21.2 Hierarchical agglomerative clustering 
A common form of clustering is known as hierarchical agglomerative clustering or HAC. The
input to the algorithm is an $N \times N$ dissimilarity matrix $D_{ij} \geq 0$, and the output is a tree structure
in which groups $i$ and $j$ with small dissimilarity are grouped together in a hierarchical fashion.
Figure 21.2: (a) An example of single link clustering using city block distance. Pairs (1,3) and (4,5) are
both distance 1 apart, so get merged first. (b) The resulting dendrogram. Adapted from Figure 7.5 of [Alp04].
Generated by agglomDemo.ipynb.

Figure 21.3: Illustration of (a) Single linkage. (b) Complete linkage. (c) Average linkage.


For example, consider the set of 5 input points in Figure 21.2(a), $x_n \in \mathbb{R}^2$. We will use city
block distance between the points to define the dissimilarity, i.e.,

$$ d_{ij} = \sum_{k=1}^{2} |x_{ik} - x_{jk}| \tag{21.9} $$


We start with a tree with N leaves, each corresponding to a cluster with a single data point. Next 
we compute the pair of points that are closest, and merge them. We see that (1,3) and (4,5) are 
both distance 1 apart, so they get merged first. We then measure the dissimilarity between the sets 
{1,3}, {4,5} and {2} using some measure (details below), and group them, and repeat. The result is 
a binary tree known as a dendrogram, as shown in Figure 21.2(b). By cutting this tree at different 
heights, we can induce a different number of (nested) clusters. We give more details below. 


21.2.1 The algorithm 


Agglomerative clustering starts with N groups, each initially containing one object, and then at 
each step it merges the two most similar groups until there is a single group, containing all the data. 
See Algorithm 21.1 for the pseudocode. Since picking the two most similar clusters to merge takes 
$O(N^2)$ time, and there are $O(N)$ steps in the algorithm, the total running time is $O(N^3)$. However,
by using a priority queue, this can be reduced to $O(N^2 \log N)$ (see e.g., [MRS08, ch. 17] for details).
Algorithm 21.1: Agglomerative clustering

Initialize clusters as singletons: for i ← 1 to n do C_i ← {i}
Initialize set of clusters available for merging: S ← {1, ..., n}
repeat
    Pick 2 most similar clusters to merge: (j, k) ← arg min_{j,k ∈ S} d_{j,k}
    Create new cluster C_ℓ ← C_j ∪ C_k
    Mark j and k as unavailable: S ← S \ {j, k}
    if C_ℓ ≠ {1, ..., n} then
        Mark ℓ as available, S ← S ∪ {ℓ}
    foreach i ∈ S do
        Update dissimilarity matrix d(i, ℓ)
until no more clusters are available for merging


Figure 21.4: Hierarchical clustering of yeast gene expression data. (a) Single linkage. (b) Complete linkage.
(c) Average linkage. Generated by hclust_yeast_demo.ipynb.


There are actually three variants of agglomerative clustering, depending on how we define the 
dissimilarity between groups of objects. We give the details below. 


21.2.1.1 Single link 


In single link clustering, also called nearest neighbor clustering, the distance between two 
groups G and H is defined as the distance between the two closest members of each group: 


$$ d_{SL}(G, H) = \min_{i \in G, i' \in H} d_{i,i'} \tag{21.10} $$

See Figure 21.3(a).

The tree built using single link clustering is a minimum spanning tree of the data, which is a tree
that connects all the objects in a way that minimizes the sum of the edge weights (distances). To
see this, note that when we merge two clusters, we connect together the two closest members of
the clusters; this adds an edge between the corresponding nodes, and this is guaranteed to be the
“lightest weight” edge joining these two clusters. And once two clusters have been merged, they will
never be considered again, so we cannot create cycles. As a consequence of this, we can actually
implement single link clustering in $O(N^2)$ time, whereas the other variants take $O(N^3)$ time.


21.2.1.2 Complete link 


In complete link clustering, also called furthest neighbor clustering, the distance between
two groups is defined as the distance between their two most distant members:

$$ d_{CL}(G, H) = \max_{i \in G, i' \in H} d_{i,i'} \tag{21.11} $$

See Figure 21.3(b).

Single linkage only requires that a single pair of objects be close for the two groups to be considered
close together, regardless of the similarity of the other members of the group. Thus clusters can be
formed that violate the compactness property, which says that all the observations within a group
should be similar to each other. In particular if we define the diameter of a group as the largest
dissimilarity of its members, $d_G = \max_{i \in G, i' \in G} d_{i,i'}$, then we can see that single linkage can produce
clusters with large diameters. Complete linkage represents the opposite extreme: two groups are
considered close only if all of the observations in their union are relatively similar. This will tend
to produce clusterings with small diameter, i.e., compact clusters. (Compare Figure 21.4(a) with
Figure 21.4(b).)


21.2.1.3 Average link 


In practice, the preferred method is average link clustering, which measures the average distance 
between all pairs: 


$$ d_{avg}(G, H) = \frac{1}{n_G n_H} \sum_{i \in G} \sum_{i' \in H} d_{i,i'} \tag{21.12} $$

where $n_G$ and $n_H$ are the number of elements in groups $G$ and $H$. See Figure 21.3(c).

Average link clustering represents a compromise between single and complete link clustering. It
tends to produce relatively compact clusters that are relatively far apart. (See Figure 21.4(c).)
However, since it involves averaging of the $d_{i,i'}$’s, any change to the measurement scale can change the
result. In contrast, single linkage and complete linkage are invariant to monotonic transformations of
$d_{i,i'}$, since they leave the relative ordering the same.
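To make the three linkage rules concrete, here is a naive Python sketch of agglomerative clustering. It is not the
code behind the figures: the five points are hypothetical (chosen so that two pairs are at city block distance 1), the
function names are ours, and for simplicity the linkage is recomputed from the original point distances at every step
rather than by updating a dissimilarity matrix, so it does not achieve the running times quoted above.

```python
def city_block(x, y):
    return sum(abs(a - b) for a, b in zip(x, y))


def linkage(G, H, d, method):
    """Dissimilarity between groups G and H (tuples of point indices)."""
    pairs = [d[i][j] for i in G for j in H]
    if method == "single":
        return min(pairs)
    if method == "complete":
        return max(pairs)
    if method == "average":
        return sum(pairs) / len(pairs)
    raise ValueError(f"unknown method: {method}")


def agglomerative(d, method="single"):
    """Return the list of merges as (members of the new cluster, merge height)."""
    clusters = {i: (i,) for i in range(len(d))}  # clusters available for merging
    next_id = len(d)
    merges = []
    while len(clusters) > 1:
        ids = sorted(clusters)
        # Pick the two most similar available clusters (ties broken by cluster id).
        dist, a, b = min(
            (linkage(clusters[a], clusters[b], d, method), a, b)
            for pos, a in enumerate(ids)
            for b in ids[pos + 1:]
        )
        merged = tuple(sorted(clusters.pop(a) + clusters.pop(b)))
        clusters[next_id] = merged
        merges.append((merged, dist))
        next_id += 1
    return merges


# Hypothetical 2d points (0-indexed); pairs (0, 2) and (3, 4) are distance 1 apart.
points = [(0, 0), (2, 5), (1, 0), (6, 0), (6, 1)]
d = [[city_block(p, q) for q in points] for p in points]
for method in ("single", "complete", "average"):
    print(method, agglomerative(d, method))
```

The merge heights are what a dendrogram plots on its vertical axis; cutting at a given height keeps only the merges
below it.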


21.2.2 Example 


Suppose we have a set of time series measurements of the expression levels for N = 300 genes at 
T = 7 points. Thus each data sample is a vector $x_n \in \mathbb{R}^7$. See Figure 21.5 for a visualization of the
data. We see that there are several kinds of genes, such as those whose expression level goes up 
monotonically over time (in response to a given stimulus), those whose expression level goes down 
monotonically, and those with more complex response patterns. 

Suppose we use Euclidean distance to compute a pairwise dissimilarity matrix, $D \in \mathbb{R}^{300 \times 300}$, and
apply HAC using average linkage. We get the dendrogram in Figure 21.6(a). If we cut the tree at
a certain height, we get the 16 clusters shown in Figure 21.6(b). The time series assigned to each
cluster do indeed “look like” each other.

Figure 21.5: (a) Some yeast gene expression data plotted as a heat map. (b) Same data plotted as a time
series. Generated by yeast_data_viz.ipynb.

Figure 21.6: Hierarchical clustering applied to the yeast gene expression data. (a) The rows are permuted
according to a hierarchical clustering scheme (average link agglomerative clustering), in order to bring similar
rows close together. (b) 16 clusters induced by cutting the average linkage tree at a certain height. Generated
by hclust_yeast_demo.ipynb.


21.2.3 Extensions 


There are many extensions to the basic HAC algorithm. For example, [Mon+21] present a more
scalable version of the bottom up algorithm that builds sub-clusters in parallel. And [Mon+19]
discuss an online version of the algorithm that can cluster data as it arrives, while reconsidering
previous clustering decisions (as opposed to only making greedy decisions). Under certain assumptions,
this can provably recover the true underlying structure. This can be useful for clustering “mentions”
of “entities” (such as people or things) in streaming text data. (This problem is called entity
discovery.)


21.3 K means clustering 


There are several problems with hierarchical agglomerative clustering (Section 21.2). First, it takes 
$O(N^3)$ time (for the average link method), making it hard to apply to big datasets. Second, it
assumes that a dissimilarity matrix has already been computed, whereas the notion of “similarity” is 
often unclear and needs to be learned. Third, it is just an algorithm, not a model, and so it is hard 
to evaluate how good it is. That is, there is no clear objective that it is optimizing. 

In this section, we discuss the K-means algorithm [Mac67; Llo82], which addresses these issues. 
First, it runs in $O(NKT)$ time, where $T$ is the number of iterations. Second, it computes similarity in
terms of Euclidean distance to learned cluster centers $\mu_k \in \mathbb{R}^D$, rather than requiring a dissimilarity
matrix. Third, it optimizes a well-defined cost function, as we will see. 


21.3.1 The algorithm 


We assume there are $K$ cluster centers $\mu_k \in \mathbb{R}^D$, so we can cluster the data by assigning each data
point $x_n \in \mathbb{R}^D$ to its closest center:

$$ z_n^* = \arg\min_k \|x_n - \mu_k\|_2^2 \tag{21.13} $$

Of course, we don’t know the cluster centers, but we can estimate them by computing the average
value of all points assigned to them:

$$ \mu_k = \frac{1}{N_k} \sum_{n : z_n = k} x_n \tag{21.14} $$

We can then iterate these steps to convergence.
More formally, we can view this as finding a local minimum of the following cost function, known
as the distortion:

$$ J(M, Z) = \sum_{n=1}^{N} \|x_n - \mu_{z_n}\|^2 = \|X - Z M^T\|_F^2 \tag{21.15} $$

where $X \in \mathbb{R}^{N \times D}$, $Z \in [0, 1]^{N \times K}$, and $M \in \mathbb{R}^{D \times K}$ contains the cluster centers $\mu_k$ in its columns.
K-means optimizes this using alternating minimization. (This is closely related to the EM algorithm
for GMMs, as we discuss in Section 21.4.1.1.)
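The alternating minimization can be sketched in plain Python as follows. This is a minimal illustration, not the
implementation used for the figures: it initializes the centers with K randomly chosen data points, keeps the old
center if a cluster becomes empty, and stops when the assignments no longer change. The names `kmeans` and `sq_dist`
and the toy data are our own.

```python
import random


def sq_dist(x, y):
    return sum((a - b) ** 2 for a, b in zip(x, y))


def kmeans(X, K, n_iters=100, seed=0):
    """Return (centers, assignments z, distortion J) for a list of points X."""
    rng = random.Random(seed)
    centers = [list(x) for x in rng.sample(X, K)]  # K random data points
    z = None
    for _ in range(n_iters):
        # Assignment step: z_n = argmin_k ||x_n - mu_k||^2.
        new_z = [min(range(K), key=lambda k: sq_dist(x, centers[k])) for x in X]
        if new_z == z:  # assignments unchanged, so we are at a local minimum
            break
        z = new_z
        # Update step: mu_k = mean of the points assigned to cluster k.
        for k in range(K):
            members = [x for x, zn in zip(X, z) if zn == k]
            if members:  # keep the old center if the cluster is empty
                centers[k] = [sum(col) / len(members) for col in zip(*members)]
    distortion = sum(sq_dist(x, centers[zn]) for x, zn in zip(X, z))
    return centers, z, distortion


# Toy data: two well-separated groups of three points each.
X = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
centers, z, J = kmeans(X, K=2)
print(z, J)
```

Each step can only decrease (or leave unchanged) the distortion, which is why the loop terminates, but the local
minimum it reaches depends on the initialization.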


21.3.2 Examples 


In this section, we give some examples of K-means clustering. 
Figure 21.7: Illustration of K-means clustering in 2d. We show the result of using two different random seeds
(left panel: distortion = 2800.54; right panel: distortion = 2978.45).
Adapted from Figure 9.5 of [Gér19]. Generated by kmeans_voronoi.ipynb.

Figure 21.8: Clustering the yeast data from Figure 21.5 using K-means clustering with K = 16. (a) Visualizing
all the time series assigned to each cluster. (b) Visualizing the 16 cluster centers as prototypical time series.
Generated by kmeans_yeast_demo.ipynb.


21.3.2.1 Clustering points in the 2d plane 


Figure 21.7 gives an illustration of K-means clustering applied to some points in the 2d plane. We see 
that the method induces a Voronoi tessellation of the points. The resulting clustering is sensitive 
to the initialization. Indeed, we see that the lower quality clustering on the right has higher distortion. 
By default, sklearn uses 10 random restarts (combined with the K-means++ initialization described 
in Section 21.3.4) and returns the clustering with lowest distortion. (In sklearn, the distortion is 
called the “inertia”.) 


21.3.2.2 Clustering gene expression time series data from yeast cells 


In Figure 21.8, we show the result of applying K-means clustering with K = 16 to the 300 × 7 yeast
time series matrix shown in Figure 21.5. We see that time series that “look similar” to each other are
assigned to the same cluster. We also see that the centroid of each cluster is a reasonable summary
of all the data points assigned to that cluster. Finally we notice that group 6 was not used, since no
points were assigned to it. However, this is just an accident of the initialization process, and we are
not guaranteed to get the same clustering, or number of clusters, if we repeat the algorithm. (We
discuss good ways to initialize the method in Section 21.3.4, and ways to choose K in Section 21.3.7.)

Figure 21.9: An image compressed using vector quantization with a codebook of size K. (a) K = 2. (b)
K = 4. (c) Original uncompressed image. Generated by vqDemo.ipynb.


21.3.3 Vector quantization 


Suppose we want to perform lossy compression of some real-valued vectors, $x_n \in \mathbb{R}^D$. A very simple
approach to this is to use vector quantization or VQ. The basic idea is to replace each real-valued
vector $x_n \in \mathbb{R}^D$ with a discrete symbol $z_n \in \{1, \ldots, K\}$, which is an index into a codebook of $K$
prototypes, $\mu_k \in \mathbb{R}^D$. Each data vector is encoded by using the index of the most similar prototype,
where similarity is measured in terms of Euclidean distance:

$$ \text{encode}(x_n) = \arg\min_k \|x_n - \mu_k\|^2 \tag{21.16} $$


We can define a cost function that measures the quality of a codebook by computing the
reconstruction error or distortion it induces:

$$ J \triangleq \frac{1}{N} \sum_{n=1}^{N} \|x_n - \text{decode}(\text{encode}(x_n))\|^2 = \frac{1}{N} \sum_{n=1}^{N} \|x_n - \mu_{z_n}\|^2 \tag{21.17} $$

where $\text{decode}(k) = \mu_k$. This is exactly the cost function that is minimized by the K-means algorithm.

Of course, we can achieve zero distortion if we assign one prototype to every data vector, by using
$K = N$ and assigning $\mu_n = x_n$. However, this does not compress the data at all. In particular, it
takes $O(NDB)$ bits, where $N$ is the number of real-valued data vectors, each of length $D$, and $B$ is
the number of bits needed to represent a real-valued scalar (the quantization accuracy to represent
each $x_n$).

We can do better by detecting similar vectors in the data, creating prototypes or centroids for
them, and then representing the data as deviations from these prototypes. This reduces the space
requirement to $O(N \log_2 K + KDB)$ bits. The $O(N \log_2 K)$ term arises because each of the $N$ data
vectors needs to specify which of the $K$ codewords it is using; and the $O(KDB)$ term arises because
we have to store each codebook entry, each of which is a $D$-dimensional vector. When $N$ is large, the
first term dominates the second, so we can approximate the rate of the encoding scheme (number of
bits needed per object) as $O(\log_2 K)$, which is typically much less than $O(DB)$.

One application of VQ is to image compression. Consider the 200 × 320 pixel image in Figure 21.9;
we will treat this as a set of $N = 64{,}000$ scalars. If we use one byte to represent each pixel (a
gray-scale intensity of 0 to 255), then $B = 8$, so we need $NB = 512{,}000$ bits to represent the image in
uncompressed form. For the compressed image, we need $O(N \log_2 K)$ bits. For $K = 4$, this is about
128kb, a factor of 4 compression, yet it results in negligible perceptual loss (see Figure 21.9(b)).
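The arithmetic behind these numbers can be written out as a short Python sketch. The function names are ours, and we
assume each index is stored with a fixed-length code of $\lceil \log_2 K \rceil$ bits.

```python
from math import ceil, log2


def raw_bits(n, d, b):
    """Bits to store n vectors of length d at b bits per scalar."""
    return n * d * b


def vq_bits(n, d, b, k):
    """Bits for VQ: one index per vector plus the codebook itself."""
    index_bits = n * ceil(log2(k))  # N log2(K) term
    codebook_bits = k * d * b       # K D B term
    return index_bits, codebook_bits


N, D, B = 200 * 320, 1, 8
uncompressed = raw_bits(N, D, B)            # 512000 bits
index_bits, codebook_bits = vq_bits(N, D, B, k=4)
print(uncompressed, index_bits, codebook_bits)  # 512000 128000 32
print(uncompressed / (index_bits + codebook_bits))
```

The codebook costs only 32 bits here, which is why the compression factor is very close to 4.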

Greater compression could be achieved if we modeled spatial correlation between the pixels, e.g., if 
we encoded 8×8 blocks (as used by JPEG). This is because the residual errors (differences from the
model’s predictions) would be smaller, and would take fewer bits to encode. This shows the deep 
connection between data compression and density estimation. See the sequel to this book, [Mur23], 
for more information. 


21.3.4 The K-means++ algorithm


K-means is optimizing a non-convex objective, and hence needs to be initialized carefully. A simple 
approach is to pick $K$ data points at random, and to use these as the initial values for $\mu_k$. We can
improve on this by using multiple restarts, i.e., we run the algorithm multiple times from different 
random starting points, and then pick the best solution. However, this can be slow. 

A better approach is to pick the centers sequentially so as to try to “cover” the data. That is,
we pick the initial point uniformly at random, and then each subsequent point is picked from the
remaining points, with probability proportional to its squared distance to the point’s closest cluster
center. That is, at iteration $t$, we pick the next cluster center to be $x_n$ with probability

$$ p(\mu_t = x_n) = \frac{D_{t-1}(x_n)}{\sum_{n'=1}^{N} D_{t-1}(x_{n'})} \tag{21.18} $$

where

$$ D_t(x) = \min_{k=1}^{t-1} \|x - \mu_k\|_2^2 \tag{21.19} $$

is the squared distance of $x$ to the closest existing centroid. Thus points that are far away from a
centroid are more likely to be picked, thus reducing the distortion. This is known as farthest point
clustering [Gon85], or K-means++ [AV07; Bah+12; Bac+16; BLK17; LS19a]. Surprisingly, this
simple trick can be shown to guarantee that the expected reconstruction error is never more than a
factor of $O(\log K)$ worse than optimal [AV07].
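The seeding procedure can be sketched as follows. This is a minimal illustration with our own function names; it
assumes the data contains at least K distinct points (otherwise all the sampling weights become zero and
`random.choices` raises an error).

```python
import random


def sq_dist(x, y):
    return sum((a - b) ** 2 for a, b in zip(x, y))


def kmeans_pp_init(X, K, seed=0):
    """Pick K initial centers from the list of points X."""
    rng = random.Random(seed)
    centers = [rng.choice(X)]  # first center: uniformly at random
    while len(centers) < K:
        # Squared distance of each point to its closest existing center.
        d2 = [min(sq_dist(x, c) for c in centers) for x in X]
        # Sample the next center with probability proportional to d2.
        # Points already chosen have weight 0, so they are not picked again.
        centers.append(rng.choices(X, weights=d2, k=1)[0])
    return centers


X = [(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0)]
print(kmeans_pp_init(X, K=2))
```

The returned centers would then be passed to the usual K-means iterations in place of a uniformly random
initialization.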


21.3.5 The K-medoids algorithm 


There is a variant of K-means called the K-medoids algorithm, in which we estimate each cluster center
$\mu_k$ by choosing the data example $x_n \in \mathcal{X}$ whose average dissimilarity to all other points in that
cluster is minimal; such a point is known as a medoid. By contrast, in K-means, we take averages
over points $x_n \in \mathbb{R}^D$ assigned to the cluster to compute the center. K-medoids can be more robust to
outliers (although that issue can also be tackled by using mixtures of Student distributions, instead
of mixtures of Gaussians). More importantly, K-medoids can be applied to data that does not live in
$\mathbb{R}^D$, where averaging may not be well defined. In K-medoids, the input to the algorithm is an $N \times N$
pairwise distance matrix, $D(n, n')$, not an $N \times D$ feature matrix.

The classic algorithm for solving the K-medoids is the partitioning around medoids or PAM 
method [KR87]. In this approach, at each iteration, we loop over all K medoids. For each medoid 
m, we consider each non-medoid point o, swap m and o, and recompute the cost (sum of all the 
distances of points to their medoid). If the cost has decreased, we keep this swap. The running time 
of this algorithm is $O(N^2 K T)$, where $T$ is the number of iterations.

There is also a simpler and faster method, known as the Voronoi iteration method due to [PJ09].
In this approach, at each iteration, we have two steps, similar to K-means. First, for each cluster
$k$, look at all the points currently assigned to that cluster, $S_k = \{n : z_n = k\}$, and then set $m_k$ to
be the index of the medoid of that set. (To find the medoid requires examining all $|S_k|$ candidate
points, and choosing the one that has the smallest sum of distances to all the other points in $S_k$.)
Second, for each point $n$, assign it to its closest medoid, $z_n = \arg\min_k D(n, m_k)$. The pseudo-code is
given in Algorithm 21.2.


Algorithm 21.2: K-medoids algorithm

Initialize m_{1:K} as a random subset of size K from {1, ..., N}
repeat
    z_n = argmin_k d(n, m_k) for n = 1 : N
    m_k = argmin_{n : z_n = k} Σ_{n' : z_{n'} = k} d(n, n') for k = 1 : K
until converged
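A direct Python transcription of the Voronoi iteration might look as follows. This is a sketch under our own naming;
it takes the dissimilarity matrix as a list of lists, and it assumes the diagonal is zero so that each medoid is
assigned to its own cluster.

```python
import random


def kmedoids(D, K, n_iters=100, seed=0):
    """Voronoi-iteration K-medoids on an N x N dissimilarity matrix D."""
    N = len(D)
    rng = random.Random(seed)
    medoids = rng.sample(range(N), K)  # indices of the initial medoids

    def assign(meds):
        # z_n = argmin_k d(n, m_k)
        return [min(range(K), key=lambda k: D[n][meds[k]]) for n in range(N)]

    for _ in range(n_iters):
        z = assign(medoids)
        new_medoids = []
        for k in range(K):
            members = [n for n in range(N) if z[n] == k]
            if not members:  # possible only with duplicate points; keep old medoid
                new_medoids.append(medoids[k])
                continue
            # m_k = member with the smallest total distance to the other members
            new_medoids.append(min(members, key=lambda n: sum(D[n][m] for m in members)))
        if new_medoids == medoids:  # converged
            break
        medoids = new_medoids
    return medoids, assign(medoids)


# Toy example: 1d points, with absolute difference as the dissimilarity.
points = [0, 1, 2, 10, 11, 12]
D = [[abs(a - b) for b in points] for a in points]
medoids, z = kmedoids(D, K=2)
print(medoids, z)
```

Note that only the matrix `D` is used, so the same code applies to objects (such as strings or graphs) for which
averaging is not defined.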


21.3.6 Speedup tricks 


K-means clustering takes $O(NKI)$ time, where $I$ is the number of iterations, but we can reduce the
constant factors using various tricks. For example, [Elk03] shows how to use the triangle inequality
to keep track of lower and upper bounds for the distances between inputs and the centroids; this
can be used to eliminate some redundant computations. Another approach is to use a minibatch
approximation, as proposed in [Scu10]. This can be significantly faster, although it can result in slightly
worse loss (see Figure 21.10).


21.3.7 Choosing the number of clusters K 


In this section, we discuss how to choose the number of clusters K in the K-means algorithm and 
other related methods. 
Figure 21.10: Illustration of batch vs mini-batch K-means clustering on the 2d data from Figure 21.7.
Left: distortion vs K. Right: Training time vs K. Adapted from Figure 9.6 of [Gér19]. Generated by
kmeans_minibatch.ipynb.

Figure 21.11: Performance of K-means and GMM vs K on the 2d dataset from Figure 21.7. (a) Distortion
on validation set vs K. Generated by kmeans_silhouette.ipynb. (b) BIC vs K. Generated by gmm_2d.ipynb.
(c) Silhouette score vs K. Generated by kmeans_silhouette.ipynb.


21.3.7.1 Minimizing the distortion 


Based on our experience with supervised learning, a natural choice for picking K is to pick the value
that minimizes the reconstruction error on a validation set, defined as follows:

$$ \text{err}(D_{\text{valid}}, K) = \frac{1}{|D_{\text{valid}}|} \sum_{n \in D_{\text{valid}}} \|x_n - \hat{x}_n\|_2^2 \tag{21.20} $$

where $\hat{x}_n = \text{decode}(\text{encode}(x_n))$ is the reconstruction of $x_n$.

Unfortunately, this technique will not work. Indeed, as we see in Figure 21.11a, the distortion 
monotonically decreases with K. To see why, note that the K-means model is a degenerate density 
model which consists of K “spikes” at the $\mu_k$ centers. As we increase K, we “cover” more of the input
space. Hence any given input point is more likely to find a close prototype to accurately represent 
it as K increases, thus decreasing reconstruction error. Thus unlike with supervised learning, we 
cannot use reconstruction error on a validation set as a way to select the best unsupervised model. 
(This comment also applies to picking the dimensionality for PCA, see Section 20.1.4.) 
21.3.7.2 Maximizing the marginal likelihood


A method that does work is to use a proper probabilistic model, such as a GMM, as we describe in 
Section 21.4.1. We can then use the log marginal likelihood (LML) of the data to perform model 
selection. 

We can approximate the LML using the BIC score as we discussed in Section 5.2.5.1. From
Equation (5.59), we have

$$ BIC(K) = \log p(D | \hat{\theta}_K) - \frac{D_K}{2} \log(N) \tag{21.21} $$

where $D_K$ is the number of parameters in a model with K clusters, and $\hat{\theta}_K$ is the MLE. We see from
Figure 21.11b that this exhibits the typical U-shaped curve, where the penalty decreases and then
increases.

The reason this works is that each cluster is associated with a Gaussian distribution that fills a 
volume of the input space, rather than being a degenerate spike. Once we have enough clusters to 
cover the true modes of the distribution, the Bayesian Occam’s razor (Section 5.2.3) kicks in, and 
starts penalizing the model for being unnecessarily complex.

See Section 21.4.1.3 for more discussion of Bayesian model selection for mixture models. 
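To see how the complexity penalty grows with K, the following sketch computes the parameter count and the BIC score,
assuming the penalty has the form $(D_K / 2) \log N$ and that the model is a GMM with full covariance matrices (so each
component contributes a mean and a symmetric covariance, plus K − 1 free mixing weights). The function names are
ours, and the log likelihood would have to come from an actual fitted model.

```python
from math import log


def gmm_num_params(K, D):
    """Free parameters of a K-component GMM in D dimensions with full covariances."""
    mixing_weights = K - 1
    means = K * D
    covariances = K * D * (D + 1) // 2  # symmetric D x D matrix per component
    return mixing_weights + means + covariances


def bic(log_lik, num_params, N):
    """BIC score: log likelihood at the MLE minus the complexity penalty."""
    return log_lik - 0.5 * num_params * log(N)


# The penalty term alone, for 2d data with N = 1000 points.
for K in (1, 2, 3, 4):
    print(K, gmm_num_params(K, D=2), 0.5 * gmm_num_params(K, D=2) * log(1000))
```

The penalty grows linearly in K, so adding a component only pays off if it raises the log likelihood by more than
the extra penalty.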


21.3.7.3 Silhouette coefficient 


In this section, we describe a common heuristic method for picking the number of clusters in a 
K-means clustering model. This is designed to work for spherical (not elongated) clusters. First we 
define the silhouette coefficient of an instance $i$ to be $sc(i) = (b_i - a_i) / \max(a_i, b_i)$, where $a_i$ is the
mean distance to the other instances in cluster $k_i = \arg\min_k \|\mu_k - x_i\|$, and $b_i$ is the mean distance
to the other instances in the next closest cluster, $k'_i = \arg\min_{k \neq k_i} \|\mu_k - x_i\|$. Thus $a_i$ is a measure
of compactness of its cluster, and $b_i$ is a measure of distance between the clusters. The silhouette
coefficient varies from -1 to +1. A value of +1 means the instance is close to all the members of its 
cluster, and far from other clusters; a value of 0 means it is close to a cluster boundary; and a value 
of -1 means it may be in the wrong cluster. We define the silhouette score of a clustering K to be 
the mean silhouette coefficient over all instances. 
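The computation can be sketched as follows. Note one assumption: rather than identifying the next closest cluster via
its centroid, this sketch takes $b_i$ to be the smallest mean distance from instance $i$ to the members of any other
cluster, which is a common variant; it also assigns a coefficient of 0 to instances in singleton clusters. It needs
at least two clusters and distinct points, and the function names are ours.

```python
from math import dist


def silhouette_coefficients(X, z):
    """Per-instance silhouette coefficients for points X with cluster labels z."""
    labels = sorted(set(z))
    coeffs = []
    for i, (x, zi) in enumerate(zip(X, z)):
        own = [dist(x, y) for j, (y, zj) in enumerate(zip(X, z)) if zj == zi and j != i]
        if not own:  # singleton cluster: use 0 by convention (assumption)
            coeffs.append(0.0)
            continue
        a = sum(own) / len(own)  # mean distance within the instance's own cluster
        b = min(                 # smallest mean distance to another cluster
            sum(dist(x, y) for y, zj in zip(X, z) if zj == k) / z.count(k)
            for k in labels
            if k != zi
        )
        coeffs.append((b - a) / max(a, b))
    return coeffs


def silhouette_score(X, z):
    coeffs = silhouette_coefficients(X, z)
    return sum(coeffs) / len(coeffs)


X = [(0.0, 0.0), (0.0, 1.0), (10.0, 0.0), (10.0, 1.0)]
good = [0, 0, 1, 1]  # groups nearby points together
bad = [0, 1, 0, 1]   # mixes the two groups
print(silhouette_score(X, good), silhouette_score(X, bad))
```

The sensible clustering gets a score close to +1, whereas the mixed-up labeling gets a negative score.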

In Figure 21.11a, we plot the distortion vs K for the data in Figure 21.7. As we explained above, 
it goes down monotonically with K. There is a slight “kink” or “elbow” in the curve at K = 3, 
but this is hard to detect. In Figure 21.11c, we plot the silhouette score vs K. Now we see a more 
prominent peak at K = 3, although it seems K = 7 is almost as good. See Figure 21.12 for a 
comparison of some of these clusterings. 

It can be informative to look at the individual silhouette coefficients, and not just the mean score. 
We can plot these in a silhouette diagram, as shown in Figure 21.13, where each colored region 
corresponds to a different cluster. The dotted vertical line is the average coefficient. Clusters with 
many points to the left of this line are likely to be of low quality. We can also use the silhouette 
diagram to look at the size of each cluster, even if the data is not 2d. 


21.3.7.4 Incrementally growing the number of mixture components 


An alternative to searching for the best value of K is to incrementally “grow” GMMs. We can start 
with a small value of K, and after each round of training, we consider splitting the cluster with the 
highest mixing weight into two, with the new centroids being random perturbations of the original 
centroid, and the new scores being half of the old scores. If a new cluster has too small a score, or 
too narrow a variance, it is removed. We continue in this way until the desired number of clusters is 
reached. See [FJ02] for details. 


Figure 21.12: Voronoi diagrams for K-means for different K on the 2d dataset from Figure 21.7. Generated 
by kmeans_silhouette.ipynb. 


21.3.7.5 Sparse estimation methods 


Another approach is to pick a large value of K, and then to use some kind of sparsity-promoting 
prior or inference method, such as variational Bayes, to “kill off” unneeded mixture components. See 
the sequel to this book, [Mur23], for details. 


21.4 Clustering using mixture models 


We have seen how the K-means algorithm can be used to cluster data vectors in $\mathbb{R}^D$. However, 
this method assumes that all clusters have the same spherical shape, which is a very restrictive 
assumption. In addition, K-means assumes that all clusters can be described by Gaussians in the 
input space, so it cannot be applied to discrete data. By using mixture models (Section 3.5), we can 
overcome both of these problems, as we illustrate below. 

Figure 21.13: Silhouette diagrams for K-means for different K on the 2d dataset from Figure 21.7. Generated 
by kmeans_silhouette.ipynb. 


21.4.1 Mixtures of Gaussians 


Recall from Section 3.5.1 that a Gaussian mixture model (GMM) is a model of the form 


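$$p(\boldsymbol{x}|\boldsymbol{\theta}) = \sum_{k=1}^K \pi_k \mathcal{N}(\boldsymbol{x}|\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) \tag{21.22}$$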


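If we know the model parameters $\boldsymbol{\theta} = (\boldsymbol{\pi}, \{\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k\})$, we can use Bayes rule to compute the 
responsibility (posterior membership probability) of cluster $k$ for data point $\boldsymbol{x}_n$: 


$$r_{nk} = p(z_n = k|\boldsymbol{x}_n, \boldsymbol{\theta}) = \frac{p(z_n = k|\boldsymbol{\theta})\, p(\boldsymbol{x}_n|z_n = k, \boldsymbol{\theta})}{\sum_{k'=1}^K p(z_n = k'|\boldsymbol{\theta})\, p(\boldsymbol{x}_n|z_n = k', \boldsymbol{\theta})} \tag{21.23}$$


Given the responsibilities, we can compute the most probable cluster assignment as follows: 


$$\hat{z}_n = \arg\max_k r_{nk} = \arg\max_k \left[\log p(\boldsymbol{x}_n|z_n = k, \boldsymbol{\theta}) + \log p(z_n = k|\boldsymbol{\theta})\right] \tag{21.24}$$

This is known as hard clustering. 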
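
As a rough sketch of Equations (21.23) and (21.24), the responsibilities can be computed in log space to avoid underflow. The helper names `log_gauss` and `responsibilities` and the toy parameters are hypothetical, chosen here only for illustration.

```python
import numpy as np

def log_gauss(X, mu, Sigma):
    """Log density of N(x | mu, Sigma) for each row x of X."""
    D = X.shape[1]
    diff = X - mu
    L = np.linalg.cholesky(Sigma)
    sol = np.linalg.solve(L, diff.T)           # whitened residuals, shape (D, N)
    maha = np.sum(sol ** 2, axis=0)            # squared Mahalanobis distances
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (D * np.log(2 * np.pi) + logdet + maha)

def responsibilities(X, pis, mus, Sigmas):
    """r[n, k] = p(z_n = k | x_n, theta), normalized with log-sum-exp."""
    log_joint = np.stack(
        [np.log(pi) + log_gauss(X, mu, S) for pi, mu, S in zip(pis, mus, Sigmas)],
        axis=1)
    log_norm = np.logaddexp.reduce(log_joint, axis=1, keepdims=True)
    return np.exp(log_joint - log_norm)

# Toy two-component model (illustrative values only).
pis = np.array([0.5, 0.5])
mus = np.array([[0.0, 0.0], [4.0, 4.0]])
Sigmas = np.array([np.eye(2), np.eye(2)])
X = np.array([[0.1, -0.2], [3.9, 4.2], [2.0, 2.0]])

r = responsibilities(X, pis, mus, Sigmas)      # soft assignments
z_hat = r.argmax(axis=1)                       # hard clustering
```

The third point lies midway between the two means, so its responsibilities are split equally, which is exactly the information that hard clustering discards.
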


21.4.1.1 K-means is a special case of EM 


Figure 21.14: Some data in 2d fit using a GMM with K = 5 components. Left column: marginal distribution 
p(x). Right column: visualization of each mixture distribution, and the hard assignment of points to their most 
likely cluster. (a-b) Full covariance. (c-d) Tied full covariance. (e-f) Diagonal covariance. (g-h) Spherical 
covariance. Color coding is arbitrary. Generated by gmm_2d.ipynb. 

Figure 21.15: Some 1d data, with a kernel density estimate superimposed. Adapted from Figure 6.2 of [Mar18]. 
Generated by gmm_identifiability_pymc3.ipynb. 

Figure 21.16: Illustration of the label switching problem when performing posterior inference for the parameters 
of a GMM. We show a KDE estimate of the posterior marginals derived from 1000 samples from 4 HMC 
chains. (a) Unconstrained model. Posterior is symmetric. (b) Constrained model, where we add a penalty to 
ensure $\mu_0 < \mu_1$. Adapted from Figure 6.6-6.7 of [Mar18]. Generated by gmm_identifiability_pymc3.ipynb. 


We can estimate the parameters of a GMM using the EM algorithm (Section 8.7.3). It turns out that 
the K-means algorithm is a special case of this algorithm, in which we make two approximations: 
we fix $\boldsymbol{\Sigma}_k = \mathbf{I}$ and $\pi_k = 1/K$ for all the clusters (so we just have to estimate the means $\boldsymbol{\mu}_k$), and 
we approximate the E step by replacing the soft responsibilities with hard cluster assignments, i.e., 
we compute $z_n^* = \arg\max_k r_{nk}$, and set $r_{nk} \approx \mathbb{I}(k = z_n^*)$ instead of using the soft responsibilities, 
$r_{nk} = p(z_n = k|\boldsymbol{x}_n, \boldsymbol{\theta})$. With this approximation, the weighted MLE problem in Equation (8.165) of 
the M step reduces to Equation (21.14), so we recover K-means. 
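
To make the reduction concrete, here is a minimal sketch of one "hard EM" iteration under the two approximations (identity covariances, uniform mixing weights, one-hot responsibilities). The function name `hard_em_step` and the toy data are assumptions for illustration; the point is only that the E step becomes a nearest-mean assignment and the M step becomes a per-cluster average.

```python
import numpy as np

def hard_em_step(X, mus):
    """One EM iteration with Sigma_k = I, pi_k = 1/K, and a hard E step."""
    # E step: with identity covariances and uniform weights, the cluster with
    # the largest responsibility is simply the one with the nearest mean.
    sq_dist = ((X[:, None, :] - mus[None, :, :]) ** 2).sum(axis=2)
    z = sq_dist.argmin(axis=1)
    # M step: the weighted MLE with one-hot weights is the cluster average.
    new_mus = mus.copy()
    for k in range(len(mus)):
        if np.any(z == k):               # keep the old mean if a cluster is empty
            new_mus[k] = X[z == k].mean(axis=0)
    return z, new_mus

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-3.0, 0.5, size=(50, 2)),
               rng.normal(3.0, 0.5, size=(50, 2))])
mus = np.array([[-1.0, 0.0], [1.0, 0.0]])
for _ in range(10):
    z, mus = hard_em_step(X, mus)
```
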

However, the assumption that all the clusters have the same spherical shape is very restrictive. 
For example, Figure 21.14 shows the marginal density and clustering induced using different shaped 
covariance matrices for some 2d data. We see that modeling this particular dataset needs the ability 
to capture off-diagonal covariance for some clusters (top row). 


21.4.1.2 Unidentifiability and label switching 


Note that we are free to permute the labels in a mixture model without changing the likelihood. This 
is called the label switching problem, and is an example of non-identifiability of the parameters. 
This can cause problems if we wish to perform posterior inference over the parameters (as opposed 
to just computing the MLE or a MAP estimate). For example, suppose we fit a GMM with K = 2 
components to the data in Figure 21.15 using HMC. The posterior over the means, $p(\mu_1, \mu_2|\mathcal{D})$, is 
shown in Figure 21.16a. We see that the marginal posterior for each component, $p(\mu_k|\mathcal{D})$, is bimodal. 
This reflects the fact that there are two equally good explanations of the data: either $\mu_1 \approx 47$ and 
$\mu_2 \approx 57$, or vice versa. 

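To break symmetry, we can add an ordering constraint on the centers, so that $\mu_1 < \mu_2$. We 
can do this by adding a penalty or potential function to the objective if the constraint is violated. 
More precisely, the penalized log joint becomes 


$$\ell'(\boldsymbol{\theta}) = \log p(\mathcal{D}|\boldsymbol{\theta}) + \log p(\boldsymbol{\theta}) + \phi(\boldsymbol{\mu}) \tag{21.25}$$

where 

$$\phi(\boldsymbol{\mu}) = \begin{cases} -\infty & \text{if } \mu_2 < \mu_1 \\ 0 & \text{otherwise} \end{cases} \tag{21.26}$$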

This has the desired effect, as shown in Figure 21.16b. 

A more general approach is to apply a transformation to the parameters, to ensure identifiability. 
That is, we sample the parameters $\boldsymbol{\theta}$ from a proposal, and then apply an invertible transformation 
$\boldsymbol{\theta}' = f(\boldsymbol{\theta})$ to them before computing the log joint, $\log p(\mathcal{D}, \boldsymbol{\theta}')$. To account for the change of 
variables (Section 2.8.3), we add the log of the determinant of the Jacobian. In the case of a 1d 
ordering transformation, which just sorts its inputs, the determinant of the Jacobian is 1, so the 
log-det-Jacobian term vanishes. 
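
The following small check illustrates why sorting is a harmless relabeling: permuting the component labels of a 1d GMM leaves the likelihood unchanged, so the sorted parameters explain the data exactly as well as the unsorted ones. The function `gmm_loglik_1d` and the numerical values are illustrative assumptions, loosely matching the two modes near 47 and 57 discussed above; this is not the PyMC3 code used for the figures.

```python
import numpy as np

def gmm_loglik_1d(x, pis, mus, sigmas):
    """Log-likelihood of 1d data under a Gaussian mixture."""
    x = x[:, None]                                   # shape (N, 1) for broadcasting
    log_comp = (np.log(pis) - 0.5 * np.log(2 * np.pi * sigmas ** 2)
                - 0.5 * ((x - mus) / sigmas) ** 2)   # shape (N, K)
    return np.logaddexp.reduce(log_comp, axis=1).sum()

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(47, 3, 100), rng.normal(57, 3, 100)])

# Parameters with the labels in the "wrong" order (mu_1 > mu_2).
pis = np.array([0.4, 0.6])
mus = np.array([57.0, 47.0])
sigmas = np.array([3.0, 2.5])

perm = np.argsort(mus)                               # relabel so the means increase
ll_original = gmm_loglik_1d(x, pis, mus, sigmas)
ll_sorted = gmm_loglik_1d(x, pis[perm], mus[perm], sigmas[perm])
```
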

Unfortunately, this approach does not scale to problems with more than one dimension, because there is 
no obvious way to enforce an ordering constraint on the centers $\boldsymbol{\mu}_k$. 


21.4.1.3 Bayesian model selection 


Once we have a reliable way to ensure identifiability, we can use Bayesian model selection techniques 
from Section 5.2.2 to select the number of clusters K. In Figure 21.17, we illustrate the results 
of fitting a GMM with K = 3 to 6 components to the data in Figure 21.15. We use the ordering 
transform on the means, and perform inference using HMC. We compare the resulting GMM model 
fits to the fit of a kernel density estimate (Section 16.3), which often over-smooths the data. We see 
fairly strong evidence for two bumps, corresponding to different subpopulations. 

We can compare these models more quantitatively by computing their WAIC (widely 
applicable information criterion) scores, which approximate the log marginal likelihood (see [Wat10; 
Wat13; VGG17] for details). The results are shown in Figure 21.18. (This kind of visualization was 
proposed in [McE20, p228].) We see that the model with K = 6 scores significantly higher than 
the other models, although K = 5 is a close second. This is consistent with the plot in Figure 21.17. 


21.4.2 Mixtures of Bernoullis 


Figure 21.17: Fitting GMMs with different numbers of clusters K to the data in Figure 21.15. Black solid 
line is KDE fit. Solid blue line is posterior mean; faint blue lines are posterior samples. Dotted lines show the 
individual Gaussian mixture components, evaluated by plugging in their posterior mean parameters. Adapted 
from Figure 6.8 of [Mar18]. Generated by gmm_chooseK_pymc3.ipynb. 


As we discussed in Section 3.5.2, we can use a mixture of Bernoullis to cluster binary data. The 
model has the form 


$$p(\boldsymbol{y}|z = k, \boldsymbol{\theta}) = \prod_{d=1}^D \mathrm{Ber}(y_d|\mu_{dk}) = \prod_{d=1}^D \mu_{dk}^{y_d} (1 - \mu_{dk})^{1 - y_d} \tag{21.27}$$


Here $\mu_{dk}$ is the probability that bit $d$ turns on in cluster $k$. We can fit this model with EM, SGD, 
MCMC, etc. See Figure 3.13 for an example, where we cluster some binarized MNIST digits. 
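
As a small illustration of Equation (21.27), the class-conditional log-likelihoods for all clusters can be computed with two matrix products. The function name `bernoulli_mixture_log_lik`, the clipping constant, and the tiny 3-bit example are assumptions made for this sketch.

```python
import numpy as np

def bernoulli_mixture_log_lik(Y, mu, eps=1e-12):
    """log p(y_n | z_n = k) for binary Y of shape (N, D) and mu of shape (D, K)."""
    mu = np.clip(mu, eps, 1 - eps)       # guard against log(0)
    return Y @ np.log(mu) + (1 - Y) @ np.log(1 - mu)   # shape (N, K)

# mu[d, k] is the probability that bit d is on in cluster k.
mu = np.array([[0.9, 0.1],
               [0.9, 0.1],
               [0.1, 0.9]])
Y = np.array([[1, 1, 0],
              [0, 0, 1]])
log_lik = bernoulli_mixture_log_lik(Y, mu)
z_hat = log_lik.argmax(axis=1)           # hard assignment under uniform mixing weights
```
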


21.5 Spectral clustering * 


In this section, we discuss an approach to clustering based on eigenvalue analysis of a pairwise 
similarity matrix. It uses the eigenvectors to derive feature vectors for each datapoint, which are 
then clustered using a feature-based clustering method, such as K-means (Section 21.3). This is 
known as spectral clustering [SM00; Lux07]. 


21.5.1 Normalized cuts 


We start by creating a weighted undirected graph W, where each data vector is a node, and the 
strength of the $i - j$ edge is a measure of similarity. Typically we only connect a node to its most 
similar neighbors, to ensure the graph is sparse, which speeds computation. 

Our goal is to find K clusters of similar points. That is, we want to find a graph partition into 
$S_1, \ldots, S_K$ disjoint sets of nodes so as to minimize some kind of cost. 

Our first attempt at a cost function is to compute the weight of connections between nodes in each 
cluster to nodes outside each cluster: 


$$\mathrm{cut}(S_1, \ldots, S_K) \triangleq \frac{1}{2} \sum_{k=1}^K W(S_k, \bar{S}_k) \tag{21.28}$$


where $W(A, B) \triangleq \sum_{i \in A, j \in B} w_{ij}$ and $\bar{S}_k = V \setminus S_k$ is the complement of $S_k$, where $V = \{1, \ldots, N\}$. 
Unfortunately the optimal solution to this often just partitions off a single node from the rest, 
since that minimizes the weight of the cut. To prevent this, we can divide by the size of each set, to 
get the following objective, known as the normalized cut: 


$$\mathrm{Ncut}(S_1, \ldots, S_K) \triangleq \frac{1}{2} \sum_{k=1}^K \frac{\mathrm{cut}(S_k, \bar{S}_k)}{\mathrm{vol}(S_k)} \tag{21.29}$$


where $\mathrm{vol}(A) \triangleq \sum_{i \in A} d_i$ is the total weight of set $A$ and $d_i = \sum_{j=1}^N w_{ij}$ is the weighted degree of 
node $i$. This splits the graph into K clusters such that nodes within each cluster are similar to each 
other, but are different to nodes in other clusters. 


Figure 21.18: WAIC scores for the different GMMs. The empty circle is the posterior mean WAIC score for 
each model, and the black lines represent the standard error of the mean. The solid circle is the in-sample 
deviance of each model, i.e., the unpenalized log-likelihood. The dashed vertical line corresponds to the 
maximum WAIC value. The gray triangle is the difference in WAIC score for that model compared to the best 
model. Adapted from Figure 6.10 of [Mar18]. Generated by gmm_chooseK_pymc3.ipynb. 
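
The normalized cut objective of Equation (21.29) can be evaluated directly for a candidate partition. In this sketch the helper names `W_between` and `ncut` and the toy graph (two triangles joined by one weak edge) are assumptions for illustration; the two-set term cut(S, S̄) is taken to equal W(S, S̄), which follows from Equation (21.28) when W is symmetric.

```python
import numpy as np

def W_between(W, A, B):
    """Total edge weight between node sets A and B."""
    return W[np.ix_(A, B)].sum()

def ncut(W, clusters):
    """Normalized cut of a partition given as a list of index lists."""
    d = W.sum(axis=1)                        # weighted degrees
    nodes = set(range(len(W)))
    total = 0.0
    for S in clusters:
        S_bar = sorted(nodes - set(S))       # complement of S
        total += W_between(W, S, S_bar) / d[S].sum()
    return 0.5 * total

# Two triangles joined by a single weak edge of weight 0.1.
W = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5)]:
    W[i, j] = W[j, i] = 1.0
W[2, 3] = W[3, 2] = 0.1

ncut_balanced = ncut(W, [[0, 1, 2], [3, 4, 5]])
ncut_single = ncut(W, [[0], [1, 2, 3, 4, 5]])
```

Here the partition that separates the two triangles has a much lower normalized cut than the one that splits off a single node.
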

We can formulate the Ncut problem in terms of searching for binary indicator vectors $\boldsymbol{c}_i$, with entries 
$c_{ik} \in \{0, 1\}$, that minimize the above objective, where $c_{ik} = 1$ iff point $i$ belongs to cluster $k$. 
Unfortunately this is NP-hard [WW93]. Below we discuss a continuous relaxation of the problem based on 
eigenvector methods that is easier to solve. 


21.5.2 Eigenvectors of the graph Laplacian encode the clustering 


In Section 20.4.9.2, we discussed the graph Laplacian, which is defined as $\mathbf{L} \triangleq \mathbf{D} - \mathbf{W}$, where W 
is a symmetric weight matrix for the graph, and $\mathbf{D} = \mathrm{diag}(d_i)$ is a diagonal matrix containing the 
weighted degree of each node, $d_i = \sum_j w_{ij}$. To get some intuition as to why L might be useful for 
graph-based clustering, we note the following result. 


Theorem 21.5.1. The set of eigenvectors of L with eigenvalue 0 is spanned by the indicator vectors 
$\mathbf{1}_{S_1}, \ldots, \mathbf{1}_{S_K}$, where $S_k$ are the K connected components of the graph. 


Proof. Let us start with the case K = 1. If $\boldsymbol{f}$ is an eigenvector with eigenvalue 0, then 
$0 = \boldsymbol{f}^\top \mathbf{L} \boldsymbol{f} = \frac{1}{2} \sum_{ij} w_{ij} (f_i - f_j)^2$. If two nodes are connected, so $w_{ij} > 0$, we must have that $f_i = f_j$. Hence $\boldsymbol{f}$ is 
constant for all vertices which are connected by a path in the graph. Now suppose K > 1. In this 
case, L will be block diagonal. A similar argument to the above shows that we will have K indicator 
functions, which “select out” the connected components. 


Figure 21.19: Results of clustering some data. (a) K-means. (b) Spectral clustering. Generated by 
spectral_clustering_demo.ipynb. 
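
Theorem 21.5.1 is easy to check numerically on a small example. The sketch below builds a graph with two disconnected triangles (an arbitrary toy graph chosen for illustration), forms L = D - W, and confirms that exactly two eigenvalues are zero and that the component indicator vectors lie in the null space.

```python
import numpy as np

# Two disconnected triangles, so the graph has K = 2 connected components.
W = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5)]:
    W[i, j] = W[j, i] = 1.0
L = np.diag(W.sum(axis=1)) - W               # graph Laplacian L = D - W

evals, evecs = np.linalg.eigh(L)             # eigenvalues in ascending order
num_zero = int(np.sum(np.abs(evals) < 1e-8)) # multiplicity of eigenvalue 0

# Indicator vectors of the two connected components.
ind1 = np.array([1, 1, 1, 0, 0, 0], dtype=float)
ind2 = 1.0 - ind1
```
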


This suggests the following clustering algorithm. Compute the eigenvectors and values of L, and 
let U be an $N \times K$ matrix with the K eigenvectors with smallest eigenvalue in its columns. (Fast 
methods for computing such “bottom” eigenvectors are discussed in [YHJ09].) Let $\boldsymbol{u}_i \in \mathbb{R}^K$ be 
the $i$th row of U. Since these $\boldsymbol{u}_i$ will be piecewise constant, we can apply K-means clustering 
(Section 21.3) to them to recover the connected components. (Note that the vectors $\boldsymbol{u}_i$ are the same 
as those computed by Laplacian eigenmaps discussed in Section 20.4.9.) 

Real data may not exhibit such clean block structure, but one can show, using results from 
perturbation theory, that the eigenvectors of a “perturbed” Laplacian will be close to these ideal 
indicator functions [NJW01]. 

In practice, it is important to normalize the graph Laplacian, to account for the fact that some 
nodes are more highly connected than others. One way to do this (proposed in [NJW01]) is to create 
a symmetric matrix 


$$\mathbf{L}_{sym} \triangleq \mathbf{D}^{-\frac{1}{2}} \mathbf{L} \mathbf{D}^{-\frac{1}{2}} = \mathbf{I} - \mathbf{D}^{-\frac{1}{2}} \mathbf{W} \mathbf{D}^{-\frac{1}{2}} \tag{21.30}$$


This time the eigenspace of 0 is spanned by $\mathbf{D}^{\frac{1}{2}} \mathbf{1}_{S_k}$. This suggests the following algorithm: find the 
smallest K eigenvectors of $\mathbf{L}_{sym}$, stack them into the matrix U, normalize each row to unit norm 
by creating $t_{ij} = u_{ij} / \sqrt{\sum_k u_{ik}^2}$ to make the matrix T, cluster the rows of T using K-means, then 
infer the partitioning of the original points. 
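
The steps of this algorithm can be sketched as follows. This is a minimal dense implementation for illustration only: the helper names, the plain Lloyd's K-means with a deterministic farthest-point initialization, the two-ring toy data, the bandwidth `sigma = 0.5`, and the choice to zero out self-similarities are all assumptions made here, and are not taken from the notebook that generated Figure 21.19.

```python
import numpy as np

def kmeans(X, K, n_iter=50):
    """Plain Lloyd's algorithm with a deterministic farthest-point initialization."""
    centers = [X[0]]
    for _ in range(1, K):
        d = np.min([np.sum((X - c) ** 2, axis=1) for c in centers], axis=0)
        centers.append(X[d.argmax()])
    centers = np.array(centers)
    for _ in range(n_iter):
        z = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)
        for k in range(K):
            if np.any(z == k):
                centers[k] = X[z == k].mean(axis=0)
    return z

def spectral_clustering(W, K):
    d = W.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(d)
    # L_sym = I - D^{-1/2} W D^{-1/2}
    L_sym = np.eye(len(W)) - d_inv_sqrt[:, None] * W * d_inv_sqrt[None, :]
    evals, evecs = np.linalg.eigh(L_sym)               # ascending eigenvalues
    U = evecs[:, :K]                                   # K smallest eigenvectors
    T = U / np.linalg.norm(U, axis=1, keepdims=True)   # rows scaled to unit norm
    return kmeans(T, K)

# Two concentric rings: not separable by K-means on the raw coordinates.
angles = np.linspace(0, 2 * np.pi, 40, endpoint=False)
ring = np.stack([np.cos(angles), np.sin(angles)], axis=1)
X = np.vstack([1.0 * ring, 4.0 * ring])
sq_dist = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=2)
sigma = 0.5                                            # kernel bandwidth (arbitrary)
W = np.exp(-sq_dist / (2 * sigma ** 2))
np.fill_diagonal(W, 0.0)                               # drop self-similarities
labels = spectral_clustering(W, K=2)
```

For large sparse graphs one would replace the dense `eigh` call with a sparse eigensolver that returns only the bottom K eigenvectors.
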


21.5.3 Example 


Figure 21.19 illustrates the method in action. In Figure 21.19(a), we see that K-means does a 
poor job of clustering, since it implicitly assumes each cluster corresponds to a spherical Gaussian. 
Next we try spectral clustering. We compute a dense similarity matrix W using a Gaussian kernel, 
$W_{ij} = \exp(-\frac{1}{2\sigma^2} \|\boldsymbol{x}_i - \boldsymbol{x}_j\|_2^2)$. We then compute the first two eigenvectors of the normalized Laplacian 
$\mathbf{L}_{sym}$. From this we infer the clustering using K-means, with K = 2; the results are shown in 
Figure 21.19(b). 


21.5.4 Connection with other methods 


Spectral clustering is closely related to several other methods for unsupervised learning, some of 
which we discuss below. 


21.5.4.1 Connection with kPCA 


Spectral clustering is closely related to kernel PCA (Section 20.4.6). In particular, kPCA uses the 
largest eigenvectors of W; these are equivalent to the smallest eigenvectors of I — W. This is similar 
to the above method, which computes the smallest eigenvectors of L = D — W. See [Ben+04a] for 
details. In practice, spectral clustering tends to give better results than kPCA. 


21.5.4.2 Connection with random walk analysis 


In practice we get better results by computing the eigenvectors of the normalized graph Laplacian. 
One way to normalize the graph Laplacian, which is used in [SM00; Mei01], is to define 


$$\mathbf{L}_{rw} \triangleq \mathbf{D}^{-1} \mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{W} \tag{21.31}$$


One can show that for $\mathbf{L}_{rw}$, the eigenspace of 0 is again spanned by the indicator vectors $\mathbf{1}_{S_k}$ [Lux07], 
so we can perform clustering directly on the K smallest eigenvectors U. 

There is an interesting connection between this approach and random walks on a graph. First 
note that $\mathbf{P} = \mathbf{D}^{-1} \mathbf{W} = \mathbf{I} - \mathbf{L}_{rw}$ is a stochastic matrix, where $p_{ij} = w_{ij}/d_i$ can be interpreted as the 
probability of going from $i$ to $j$. If the graph is connected and non-bipartite, it possesses a unique 
stationary distribution $\boldsymbol{\pi} = (\pi_1, \ldots, \pi_N)$, where $\pi_i = d_i/\mathrm{vol}(V)$, and $\mathrm{vol}(V) = \sum_i d_i$ is the sum of 
all the node degrees. Furthermore, one can show that for a partition of size 2, 


$$\mathrm{Ncut}(S, \bar{S}) = p(\bar{S}|S) + p(S|\bar{S}) \tag{21.32}$$


This means that we are looking for a cut such that a random walk spends more time transitioning to 
similar points, and rarely makes transitions from $S$ to $\bar{S}$ or vice versa. This analysis can be extended 
to K > 2; for details, see [Mei01]. 
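
The random walk identity can be verified numerically on a small random graph. In this sketch the two-set normalized cut is taken to be W(S, S̄)/vol(S) + W(S, S̄)/vol(S̄), i.e., without a leading factor of 1/2; under a convention that includes that factor, the identity holds up to the same constant. The graph, the partition, and the helper `cond_prob` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.uniform(size=(6, 6))
W = (A + A.T) / 2                    # random symmetric weights
np.fill_diagonal(W, 0.0)

d = W.sum(axis=1)                    # weighted degrees
P = W / d[:, None]                   # transition matrix P = D^{-1} W
pi = d / d.sum()                     # stationary distribution pi_i = d_i / vol(V)

S = np.array([0, 1, 2])
S_bar = np.array([3, 4, 5])

def cond_prob(A_idx, B_idx):
    """p(B | A): probability of stepping into B, given a stationary start in A."""
    joint = (pi[A_idx, None] * P[np.ix_(A_idx, B_idx)]).sum()
    return joint / pi[A_idx].sum()

cut = W[np.ix_(S, S_bar)].sum()
ncut = cut / d[S].sum() + cut / d[S_bar].sum()
rw = cond_prob(S, S_bar) + cond_prob(S_bar, S)
```
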


21.6 Biclustering * 

In some cases, we have a data matrix $\mathbf{X} \in \mathbb{R}^{N_r \times N_c}$ and we want to cluster the rows and the columns; 
this is known as biclustering or coclustering. This is widely used in bioinformatics, where the 
rows often represent genes and the columns represent conditions. It can also be used for collaborative 
filtering, where the rows represent users and the columns represent movies. 

A variety of ad hoc methods for biclustering have been proposed; see [MO04] for a review. In 
Section 21.6.1, we present a simple probabilistic generative model in which we assign a latent cluster 
id to each row, and a different latent cluster id to each column. In Section 21.6.2, we extend this 
to the case where each row can belong to multiple clusters, depending on which groups of features 
(columns) we choose to use to define the different groups of objects (rows). 


Organism clusters shown in Figure 21.20: 
O1: killer whale, blue whale, humpback, seal, walrus, dolphin 
O2: antelope, horse, giraffe, zebra, deer 
O3: monkey, gorilla, chimp 
O4: hippo, elephant, rhino 
O5: grizzly bear, polar bear 

Feature clusters shown in Figure 21.20: 
F1: flippers, strain teeth, swims, arctic, coastal, ocean, water 
F2: hooves, long neck, horns 
F3: hands, bipedal, jungle, tree 
F4: bulbous body shape, slow, inactive 
F5: meat teeth, eats meat, hunter, fierce 
F6: walks, quadrupedal, ground 


Figure 21.20: Illustration of biclustering. We show 5 of the 12 organism clusters, and 6 of the 33 feature 
clusters. The original data matrix is shown, partitioned according to the discovered clusters. From Figure 3 
of [Kem+06]. Used with kind permission of Charles Kemp. 


21.6.1 Basic biclustering 


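Here we present a simple probabilistic generative model for biclustering based on [Kem+06] (see also 
[SMM03] for a related approach). The idea is to associate each row and each column with a latent 
indicator, $u_i \in \{1, \ldots, N_u\}$, $v_j \in \{1, \ldots, N_v\}$, where $N_u$ is the number of row clusters, and $N_v$ is the 
number of column clusters. We then use the following generative model: 


$$p(\mathbf{U}) = \prod_{i=1}^{N_r} \mathrm{Unif}(u_i|\{1, \ldots, N_u\}) \tag{21.33}$$

$$p(\mathbf{V}) = \prod_{j=1}^{N_c} \mathrm{Unif}(v_j|\{1, \ldots, N_v\}) \tag{21.34}$$

$$p(\mathbf{X}|\mathbf{U}, \mathbf{V}, \boldsymbol{\theta}) = \prod_{i=1}^{N_r} \prod_{j=1}^{N_c} p(X_{ij}|\boldsymbol{\theta}_{u_i, v_j}) \tag{21.35}$$


where $\boldsymbol{\theta}_{a,b}$ are the parameters for row cluster $a$ and column cluster $b$. 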

Figure 21.20 shows a simple example. The data has the form $X_{ij} = 1$ iff animal $i$ has feature 
$j$, where $i = 1{:}50$ and $j = 1{:}85$. The animals represent whales, bears, horses, etc. The 
features represent properties of the habitat (jungle, tree, coastal), or anatomical properties (has 
teeth, quadrupedal), or behavioral properties (swims, eats meat), etc. The method discovered 12 
animal clusters and 33 feature clusters. ([Kem+06] use a Bayesian nonparametric method to infer 
the number of clusters.) For example, the O2 cluster is { antelope, horse, giraffe, zebra, deer }, which 
is characterized by feature clusters F2 = { hooves, long neck, horns } and F6 = { walks, quadrupedal, 
ground }, whereas the O4 cluster is { hippo, elephant, rhino }, which is characterized by feature 
clusters F4 = { bulbous body shape, slow, inactive } and F6. 


21.6.2 Nested partition models (Crosscat) 


The problem with basic biclustering (Section 21.6.1) is that each object (row) can only belong to 
one cluster. Intuitively, an object can have multiple roles, and can be assigned to different clusters 
depending on which subset of features you use. For example, in the animal dataset, we may want to 
group the animals on the basis of anatomical features (e.g., mammals are warm blooded, reptiles are 
not), or on the basis of behavioral features (e.g., predators vs prey). 


Figure 21.21: (a) Example of biclustering. Each row is assigned to a unique cluster, and each column is 
assigned to a unique cluster. (b) Example of multi-clustering using a nested partition model. The rows can 
belong to different clusters depending on which subset of column features we are looking at. 


Draft of “Probabilistic Machine Learning: An Introduction”. June 22, 2023 

We now present a model that can capture this phenomenon. We illustrate the method with an 
example. Suppose we have a 6 x 6 matrix, with $N_u = 2$ row clusters and $N_v = 3$ column clusters. 
Furthermore, suppose the latent column assignments are as follows: $\boldsymbol{v} = [1, 1, 2, 3, 3, 3]$. This means 
we put columns 1 and 2 into group 1, column 3 into group 2, and columns 4 to 6 into group 3. For 
the columns that get clustered into group 1, we cluster the rows as follows: $\boldsymbol{u}_{:,1} = [1, 1, 1, 2, 2, 2]$; for 
the columns that get clustered into group 2, we cluster the rows as follows: $\boldsymbol{u}_{:,2} = [1, 1, 2, 2, 2, 2]$; and 
for the columns that get clustered into group 3, we cluster the rows as follows: $\boldsymbol{u}_{:,3} = [1, 1, 1, 1, 1, 2]$. 
The resulting partition is shown in Figure 21.21(b). We see that the clustering of the rows depends 
on which group of columns we choose to focus on. 
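
The 6 x 6 example can be written out explicitly. In the sketch below, each cell (i, j) is labeled with a pair (row cluster of row i within the column group of column j, column group of column j); the array name `Z` and the dictionary `u` are just illustrative containers for the quantities given in the text.

```python
import numpy as np

v = np.array([1, 1, 2, 3, 3, 3])            # column -> column group
u = {1: np.array([1, 1, 1, 2, 2, 2]),       # row clustering for column group 1
     2: np.array([1, 1, 2, 2, 2, 2]),       # row clustering for column group 2
     3: np.array([1, 1, 1, 1, 1, 2])}       # row clustering for column group 3

n_rows, n_cols = 6, 6
# Z[i, j] = (row cluster of row i within column j's group, column group of j)
Z = np.empty((n_rows, n_cols, 2), dtype=int)
for j in range(n_cols):
    group = int(v[j])
    Z[:, j, 0] = u[group]
    Z[:, j, 1] = group

# Distinct coclusters actually used by this partition.
coclusters = {tuple(int(t) for t in z) for z in Z.reshape(-1, 2)}
```

Row 3 (index 2) falls into row cluster 1 when we look at columns 1 and 2, but into row cluster 2 when we look at column 3, which is the kind of feature-dependent membership that basic biclustering cannot express.
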

Formally, we can define the model as follows: 


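$$p(\mathbf{U}) = \prod_{i=1}^{N_r} \prod_{l=1}^{N_v} \mathrm{Unif}(u_{il}|\{1, \ldots, N_u\}) \tag{21.36}$$

$$p(\mathbf{V}) = \prod_{j=1}^{N_c} \mathrm{Unif}(v_j|\{1, \ldots, N_v\}) \tag{21.37}$$

$$p(\mathbf{Z}|\mathbf{U}, \mathbf{V}) = \prod_{i=1}^{N_r} \prod_{j=1}^{N_c} \delta\left(z_{ij} = (u_{i, v_j}, v_j)\right) \tag{21.38}$$

$$p(\mathbf{X}|\mathbf{Z}, \boldsymbol{\theta}) = \prod_{i=1}^{N_r} \prod_{j=1}^{N_c} p(X_{ij}|\boldsymbol{\theta}_{z_{ij}}) \tag{21.39}$$


where $\boldsymbol{\theta}_{k,l}$ are the parameters for cocluster $k \in \{1, \ldots, N_u\}$ and $l \in \{1, \ldots, N_v\}$. 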

This model was independently proposed in [Sha+06; Man+16], who call it crosscat (for 
cross-categorization), in [Gua+10; CFD10], who call it multi-clust, and in [RG11], who call it nested 
partitioning. In all of these papers, the authors propose to use Dirichlet processes, to avoid the 
problem of estimating the number of clusters. Here we assume the number of clusters is known, and 
show the parameters explicitly, for notational simplicity. 

Figure 21.22 illustrates the model applied to some binary data containing 22 animals and 106 
features. The figure shows the (approximate) MAP partition. The first partition of the columns 
contains taxonomic features, such as “has bones”, “is warm-blooded”, “lays eggs”, etc. This divides the 
animals into birds, reptiles/amphibians, mammals, and invertebrates. The second partition of the 
columns contains features that are treated as noise, with no apparent structure (except for the single 
row labeled “frog”). The third partition of the columns contains ecological features like “dangerous”, 
“carnivorous”, “lives in water”, etc. This divides the animals into prey, land predators, sea predators 
and air predators. Thus each animal (row) can belong to a different cluster depending on which set of 
features is considered. 


Figure 21.22: MAP estimate produced by the crosscat system when applied to a binary data matrix of animals 
(rows) by features (columns). See text for details. From Figure 7 of [Sha+06]. Used with kind permission of 
Vikash Mansinghka. 




22 Recommender Systems 


Recommender systems are systems which recommend items (such as movies, books, ads) to 
users based on various information, such as their past viewing/purchasing behavior (e.g., which 
movies they rated high or low, which ads they clicked on), as well as optional “side information” such 
as demographics about the user, or information about the content of the item (e.g., its title, genre 
or price). Such systems are widely used by various internet companies, such as Facebook, Amazon, 
Netflix, Google, etc. In this chapter, we give a brief introduction to the topic. More details can be 
found in e.g., [DKK12; Pat12; Yan+14; AC16; Agg16; Zha+19b]. 


22.1 Explicit feedback 


In this section, we consider the simplest setting in which the user gives explicit feedback to the 
system in terms of a rating, such as +1 or -1 (for like/dislike) or a score from 1 to 5. Let $Y_{ui} \in \mathbb{R}$ 
be the rating that user $u$ gives to item $i$. We can represent this as an M x N matrix, where M is 
the number of users, and N is the number of items. Typically this matrix will be very large but 
very sparse, since most users will not provide any feedback on most items. See Figure 22.1(a) for an 
example. We can also view this sparse matrix as a bipartite graph, where the weight of the $u - i$ 
edge is $Y_{ui}$. This reflects the fact that we are dealing with relational data, i.e., the values of $u$ and 
$i$ have no intrinsic meaning (they are just arbitrary indices); it is the fact that $u$ and $i$ are connected 
that matters. 

If $Y_{ui}$ is missing, it could be because user $u$ has not interacted with item $i$, or it could be that they 
knew they wouldn’t like it and so they chose not to engage with it. In the former case, some of the 
data is missing at random; in the latter case, the missingness is informative about the true value 
of $Y_{ui}$. (See e.g., [Mar+11] for further discussion of this point.) We will assume the data is missing 
at random, for simplicity. 


22.1.1 Datasets 


A famous example of an explicit ratings matrix was made available by the movie streaming company 
Netflix. In 2006, they released a large dataset of 100,480,507 movie ratings (on a scale of 1 to 5) 
from 480,189 users of 17,770 movies. Despite the large size of the training set, the ratings matrix 
is still 99% sparse (unknown). Along with the data, they offered a prize of $1M, known as the 
Netflix Prize, to any team that could predict the true ratings of a set of test (user, item) pairs 
more accurately than their incumbent system. The prize was claimed on September 21, 2009 by 
a team known as “BellKor’s Pragmatic Chaos”. They used an ensemble of different methods, as 
described in [Kor09; BK07; FHK12]. However, a key component in their ensemble was the method 
described in Section 22.1.3. 


Figure 22.1: Example of a relational dataset represented as a sparse matrix (left) or a sparse bipartite graph 
(right). Values corresponding to empty cells (missing edges) are unknown. Rows 3 and 4 are similar to each 
other, indicating that users 3 and 4 might have similar preferences, so we can use the data from user 3 to 
predict user 4’s preferences. However, user 1 seems quite different in their preferences, and seems to give low 
ratings to all items. For user 2, we have very little observed data, so it is hard to make reliable predictions. 

Unfortunately the Netflix data is no longer available due to privacy concerns. Fortunately the 
MovieLens group at the University of Minnesota have released an anonymized public dataset of 
movie ratings, on a scale of 1-5, that can be used for research [HK15]. There are also various 
other public explicit ratings datasets, such as the Jester jokes dataset from [Gol+01] and the 
BookCrossing dataset from [Zie+05]. 


22.1.2 Collaborative filtering 


The original approach to the recommendation problem is called collaborative filtering [Gol+92]. 
The idea is that users collaborate on recommending items by sharing their ratings with other users; 
then if user $u$ wants to know whether they should interact with item $i$, they can see what ratings other 
users $u'$ have given to $i$, and take a weighted average: 


$$\hat{Y}_{ui} = \sum_{u': Y_{u'i} \neq ?} \mathrm{sim}(u, u')\, Y_{u'i} \tag{22.1}$$


where we assume $Y_{ui} = ?$ if the entry is unknown. The traditional approach measured the similarity 
of two users by comparing the sets $S_u = \{Y_{ui} \neq ? : i \in \mathcal{I}\}$ and $S_{u'} = \{Y_{u'i} \neq ? : i \in \mathcal{I}\}$, where $\mathcal{I}$ 
is the set of items. However, this can suffer from data sparsity. In Section 22.1.3 we discuss an 
approach based on learning dense embedding vectors for each item and each user, so we can compute 
similarity in a low dimensional feature space. 
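
A minimal sketch of this kind of neighborhood-based prediction is shown below. Two choices here are assumptions rather than part of the text: the similarity is taken to be cosine similarity over the items both users have rated, and the weights are normalized by the sum of their absolute values so that the prediction is a proper weighted average. Unknown ratings ("?") are represented by NaN.

```python
import numpy as np

def cosine_sim_on_overlap(Y, u, u2):
    """Cosine similarity of two users over the items both have rated.
    One possible choice of sim(u, u'); the text does not fix a specific one."""
    both = ~np.isnan(Y[u]) & ~np.isnan(Y[u2])
    if not both.any():
        return 0.0
    a, b = Y[u, both], Y[u2, both]
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def predict(Y, u, i):
    """Similarity-weighted average of the ratings other users gave to item i."""
    num, den = 0.0, 0.0
    for u2 in range(Y.shape[0]):
        if u2 == u or np.isnan(Y[u2, i]):
            continue                      # skip the user themself and unknown ratings
        s = cosine_sim_on_overlap(Y, u, u2)
        num += s * Y[u2, i]
        den += abs(s)
    return num / den if den > 0 else np.nan

nan = np.nan                              # stands for the unknown rating "?"
Y = np.array([[5.0, 4.0, nan, 1.0],       # user 0 has not rated item 2
              [5.0, 4.0, 5.0, 1.0],       # user 1 agrees with user 0
              [1.0, 2.0, 1.0, 5.0]])      # user 2 has different tastes
y_hat = predict(Y, u=0, i=2)
```

With very sparse data, two users often share few or no rated items, so these similarities become unreliable.
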

22.1.3 Matrix factorization 


We can view the recommender problem as one of matrix completion, in which we wish to predict 
all the missing entries of Y. We can formulate this as the following optimization problem: 


$$\mathcal{L}(\mathbf{Z}) = \sum_{ij: Y_{ij} \neq ?} (Z_{ij} - Y_{ij})^2 = \|\mathbf{Z} - \mathbf{Y}\|_F^2 \tag{22.2}$$


However, this is an under-specified problem, since there are an infinite number of ways of filling in 
the missing entries of Z. 

We need to add some constraints. Suppose we assume that Y is low rank. Then we can write it in 
the form $\mathbf{Z} = \mathbf{U}\mathbf{V}^\top \approx \mathbf{Y}$, where U is an M x K matrix, V is an N x K matrix, K is the rank of the 
matrix, M is the number of users, and N is the number of items. This corresponds to a prediction of 
the form 


$$\hat{y}_{ui} = \boldsymbol{u}_u^\top \boldsymbol{v}_i \tag{22.3}$$


This is called matrix factorization. 

If we observe all the $Y_{ij}$ entries, we can find the optimal Z using SVD (Section 7.5). However, 
when Y has missing entries, the corresponding objective is no longer convex, and does not have a 
unique optimum [SJ03]. We can fit this using alternating least squares (ALS), where we estimate 
U given V and then estimate V given U (for details, see e.g., [KBV09]). Alternatively we can just 
use SGD. 

In practice, it is important to also allow for user-specific and item-specific baselines, by writing 


$$\hat{y}_{ui} = \mu + b_u + c_i + \boldsymbol{u}_u^\top \boldsymbol{v}_i \tag{22.4}$$


This can capture the fact that some users might always tend to give low ratings and others may give 
high ratings; in addition, some items (e.g., very popular movies) might have unusually high ratings. 
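In addition, we can add some $\ell_2$ regularization to the parameters to get the objective 


$$\mathcal{L}(\boldsymbol{\theta}) = \sum_{ui: Y_{ui} \neq ?} (y_{ui} - \hat{y}_{ui})^2 + \lambda \left(b_u^2 + c_i^2 + \|\boldsymbol{u}_u\|^2 + \|\boldsymbol{v}_i\|^2\right) \tag{22.5}$$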


We can optimize this using SGD by sampling a random (u, i) entry from the set of observed values, 
and performing the following updates: 


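$$b_u = b_u + \eta (e_{ui} - \lambda b_u) \tag{22.6}$$
$$c_i = c_i + \eta (e_{ui} - \lambda c_i) \tag{22.7}$$
$$\boldsymbol{u}_u = \boldsymbol{u}_u + \eta (e_{ui} \boldsymbol{v}_i - \lambda \boldsymbol{u}_u) \tag{22.8}$$
$$\boldsymbol{v}_i = \boldsymbol{v}_i + \eta (e_{ui} \boldsymbol{u}_u - \lambda \boldsymbol{v}_i) \tag{22.9}$$


where $e_{ui} = y_{ui} - \hat{y}_{ui}$ is the error term, and $\eta > 0$ is the learning rate. This approach was first 
proposed by Simon Funk, who was one of the first to do well in the early days of the Netflix 
competition.$^1$ 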


1. https://sifter.org/~simon/journal/20061211.html. 
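
The SGD updates in Equations (22.6) to (22.9) translate almost line for line into code. The sketch below is an illustration under stated assumptions, not Funk's implementation: the function name `fit_mf_sgd`, the hyperparameter values, the use of the global mean rating for $\mu$, and the tiny set of (user, item, rating) triples are all made up for the example.

```python
import numpy as np

def fit_mf_sgd(triples, M, N, K=2, lr=0.02, lam=0.01, n_epochs=500, seed=0):
    """Matrix factorization with user/item baselines, fit by SGD."""
    rng = np.random.default_rng(seed)
    mu = np.mean([y for _, _, y in triples])     # global mean rating (held fixed)
    b, c = np.zeros(M), np.zeros(N)              # user and item baselines
    U = 0.1 * rng.standard_normal((M, K))        # user embeddings
    V = 0.1 * rng.standard_normal((N, K))        # item embeddings
    for _ in range(n_epochs):
        for idx in rng.permutation(len(triples)):
            u, i, y = triples[idx]
            e = y - (mu + b[u] + c[i] + U[u] @ V[i])   # error term e_ui
            b[u] += lr * (e - lam * b[u])
            c[i] += lr * (e - lam * c[i])
            u_old = U[u].copy()                  # use the old u_u in the v_i update
            U[u] += lr * (e * V[i] - lam * U[u])
            V[i] += lr * (e * u_old - lam * V[i])
    return mu, b, c, U, V

# Observed (user, item, rating) triples; all other entries are unknown.
triples = [(0, 0, 5.0), (0, 1, 4.0), (1, 0, 4.0), (1, 2, 1.0),
           (2, 1, 2.0), (2, 2, 1.0), (3, 0, 5.0), (3, 2, 2.0)]
mu, b, c, U, V = fit_mf_sgd(triples, M=4, N=3)

preds = np.array([mu + b[u] + c[i] + U[u] @ V[i] for u, i, _ in triples])
targets = np.array([y for _, _, y in triples])
rmse = np.sqrt(np.mean((preds - targets) ** 2))  # training error only
```

In practice one would monitor the error on held-out ratings rather than the training error, and tune the learning rate, regularizer, and number of factors accordingly.
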

Figure 22.2: Visualization of the first two latent movie factors estimated from the Netflix challenge data. 
Each movie j is plotted at the location specified by $\boldsymbol{v}_j$. See text for details. From Figure 3 of [KBV09]. Used 
with kind permission of Yehuda Koren. 


22.1.3.1 Probabilistic matrix factorization (PMF) 


We can convert matrix factorization into a probabilistic model by defining 
$$p(Y_{ui} = y) = \mathcal{N}(y|\mu + b_u + c_i + \boldsymbol{u}_u^\top \boldsymbol{v}_i, \sigma^2) \tag{22.10}$$


This is known as probabilistic matrix factorization (PMF) [SM08]. The NLL of this model 
is equivalent to the matrix factorization objective in Equation (22.2). However, the probabilistic 
perspective allows us to generalize the model more easily. For example, we can capture the fact 
that the ratings are integers (often mostly 0s), and not reals, using a Poisson or negative Binomial 
likelihood (see e.g., [GOF18]). This is similar to exponential family PCA (Section 20.2.7), except 
that we view rows and columns symmetrically. 


22.1.3.2 Example: Netflix 


Suppose we apply PMF to the Netflix dataset using K = 2 latent factors. Figure 22.2 visualizes the
learned embedding vectors $v_i$ for a few movies. On the left of the plot we have low-brow humor and
horror movies (Half Baked, Freddy vs Jason), and on the right we have more serious dramas (Sophie’s
Choice, Moonstruck). On the top we have critically acclaimed independent movies (Punch-Drunk Love,
I Heart Huckabees), and on the bottom we have mainstream Hollywood blockbusters (Armageddon,
Runaway Bride). The Wizard of Oz is right in the middle of these axes, since it is in some sense an
“average movie”.

Users are embedded into the same space as movies. We can then predict the rating for any
user-movie pair using proximity in the latent embedding space.


Figure 22.3: (a) A fragment of the observed ratings matrix from the MovieLens-1M dataset. (b) Predictions
using SVD with 50 latent components. Generated by matrix_factorization_recommender.ipynb.


22.1.3.3 Example: MovieLens 


Now suppose we apply PMF to the MovieLens-1M dataset with 6040 users, 3706 movies, and 1,000,209 
ratings. We will use K = 50 factors. For simplicity, we fit this using SVD applied to the dense 
ratings matrix, where we replace missing values with 0. (This is just a simple approximation to keep 
the demo code simple.) In Figure 22.3 we show a snippet of the true and predicted ratings matrix. 
(We truncate the predictions to lie in the range [1,5].) We see that the model is not particularly 
accurate, but does capture some structure in the data. 
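
The zero-fill-then-SVD shortcut described above can be sketched in a few lines of NumPy. This is a minimal illustration on a made-up 4 x 4 ratings matrix, not the code of the referenced notebook; the function name `svd_predict` and the toy data are assumptions.

```python
import numpy as np

def svd_predict(Y, K, lo=1.0, hi=5.0):
    """Rank-K reconstruction of a dense ratings matrix, clipped to [lo, hi].

    Y is users x items, with missing ratings already replaced by 0.
    """
    U, s, Vt = np.linalg.svd(Y, full_matrices=False)
    Y_hat = (U[:, :K] * s[:K]) @ Vt[:K, :]   # keep the top-K singular triples
    return np.clip(Y_hat, lo, hi)            # truncate predictions to the rating range

# Toy ratings matrix (0 marks a missing rating); the values are made up.
Y = np.array([[5, 4, 0, 1],
              [4, 0, 1, 1],
              [1, 1, 0, 5],
              [0, 1, 5, 4]], dtype=float)
Y_hat = svd_predict(Y, K=2)
```

Because the zeros are treated as real ratings, the low-rank fit is pulled toward 0 on the missing entries, which is one reason this approximation is crude.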

Furthermore, it seems to behave in a qualitatively sensible way. For example, in Figure 22.4 we 
show the top 10 movies rated by a given user as well as the top 10 predictions for movies they had not 
seen. The model seems to have “picked up” on the underlying preferences of the user. For example, 
we see that many of the predicted movies are action or film-noir, and both of these genres feature in 
the user’s own top-10 list, even though explicit genre information is not used during model training. 


22.1.4 Autoencoders 


Matrix factorization is a (bi)linear model. We can make a nonlinear version using autoencoders. Let
$y_{:,i} \in \mathbb{R}^M$ be the $i$'th column of the ratings matrix, where unknown ratings are set to 0. We can
predict this ratings vector using an autoencoder of the form


$$f(y_{:,i}; \theta) = W^\top \varphi(V y_{:,i} + \mu) + b \tag{22.11}$$


where $V \in \mathbb{R}^{K \times M}$ maps the ratings to an embedding space, $W \in \mathbb{R}^{K \times M}$ maps the embedding space to
a distribution over ratings, $\mu \in \mathbb{R}^K$ are the biases of the hidden units, and $b \in \mathbb{R}^M$ are the biases
of the output units. This is called the (item-based) version of the AutoRec model [Sed+15]. This
has $2MK + M + K$ parameters. There is also a user-based version, that can be derived in a similar
manner, which has $2NK + N + K$ parameters. (On MovieLens and Netflix, the authors find that
the item-based method works better.)


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 


MovieID  Title                                          Genres
858      Godfather, The (1972)                          Action|Crime|Drama
1387     Jaws (1975)                                    Action|Horror
2028     Saving Private Ryan (1998)                     Action|Drama|War
1221     Godfather: Part II, The (1974)                 Action|Crime|Drama
913      Maltese Falcon, The (1941)                     Film-Noir|Mystery
3417     Crimson Pirate, The (1952)                     Adventure|Comedy|Sci-Fi
2186     Strangers on a Train (1951)                    Film-Noir|Thriller
2791     Airplane! (1980)                               Comedy
1188     Strictly Ballroom (1992)                       Comedy|Romance
1304     Butch Cassidy and the Sundance Kid (1969)      Action|Comedy|Western
(a)

MovieID  Title                                          Genres
527      Schindler's List (1993)                        Drama|War
1953     French Connection, The (1971)                  Action|Crime|Drama|Thriller
608      Fargo (1996)                                   Crime|Drama|Thriller
1284     Big Sleep, The (1946)                          Film-Noir|Mystery
2194     Untouchables, The (1987)                       Action|Crime|Drama
1230     Annie Hall (1977)                              Comedy|Romance
1242     Glory (1989)                                   Action|Drama|War
922      Sunset Blvd. (a.k.a. Sunset Boulevard) (1950)  Film-Noir
1954     Rocky (1976)                                   Action|Drama
593      Silence of the Lambs, The (1991)               Drama|Thriller
(b)


Figure 22.4: (a) Top 10 movies (from a list of 69) that user “837” has already highly rated. (b) Top 10
predictions (from a list of 3637) from the algorithm. Generated by matrix_factorization_recommender.ipynb.


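We can fit this by only updating parameters that are associated with the observed entries of $y_{:,i}$.
Furthermore, we can add an $\ell_2$ regularizer to the weight matrices to get the objective


$$\mathcal{L}(\theta) = \sum_{i=1}^{N} \sum_{u:\, y_{ui}\ \text{observed}} \left(y_{ui} - f(y_{:,i}; \theta)_u\right)^2 + \frac{\lambda}{2}\left(\|W\|_F^2 + \|V\|_F^2\right) \tag{22.12}$$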


Despite the simplicity of this method, the authors find that this does better than more complex 
methods such as restricted Boltzmann machines (RBMs, [SMH07]) and local low-rank matrix 
approximation (LLORMA, [Lee+13]).
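
The item-based autoencoder and its masked objective can be sketched as follows. This is a minimal NumPy illustration under stated assumptions: the sigmoid hidden nonlinearity, the factor 1/2 on the penalty, the function names, and the random toy data are all choices made here for illustration and are not taken from [Sed+15].

```python
import numpy as np

def autorec_forward(y, V, W, mu, b):
    """f(y) = W^T phi(V y + mu) + b, with a sigmoid as phi (an assumed choice)."""
    h = 1.0 / (1.0 + np.exp(-(V @ y + mu)))    # hidden code, shape (K,)
    return W.T @ h + b                          # predicted ratings, shape (M,)

def autorec_loss(Y, mask, V, W, mu, b, lam):
    """Squared error over observed entries plus (lam / 2) * Frobenius penalty."""
    loss = 0.0
    for i in range(Y.shape[1]):                 # loop over item columns
        r = Y[:, i] - autorec_forward(Y[:, i], V, W, mu, b)
        loss += np.sum((mask[:, i] * r) ** 2)   # unobserved entries contribute 0
    return loss + 0.5 * lam * (np.sum(W ** 2) + np.sum(V ** 2))

rng = np.random.default_rng(0)
M, N, K = 6, 4, 3                               # users, items, hidden units
mask = rng.random((M, N)) < 0.5                 # True where a rating is observed
Y = mask * rng.integers(1, 6, size=(M, N))      # unknown ratings are set to 0
V = 0.1 * rng.standard_normal((K, M))
W = 0.1 * rng.standard_normal((K, M))
mu, b = np.zeros(K), np.zeros(M)
n_params = V.size + W.size + mu.size + b.size   # equals 2MK + M + K
loss = autorec_loss(Y, mask, V, W, mu, b, lam=0.1)
```

Multiplying the residual by the mask means the gradient with respect to unobserved outputs is zero, which is what "only updating parameters associated with observed entries" amounts to on the output side.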


22.2 Implicit feedback 


So far, we have assumed that the user gives explicit ratings for each item that they interact with. 
This is a very restrictive assumption. More generally, we would like to learn from the implicit 
feedback that users give just by interacting with a system. For example, we can treat the list of 
movies that user u watches as positives, and regard all the other movies as negatives. Thus we get a 
sparse, positive-only ratings matrix. 

Alternatively, we can view the fact that they watched movie $i$ but did not watch movie $j$ as an
implicit signal that they prefer $i$ to $j$. The resulting data can be represented as a set of tuples of the
form $y_n = (u, i, j)$, where $(u, i)$ is a positive pair, and $(u, j)$ is a negative (or unlabeled) pair.


22.2.1 Bayesian personalized ranking 


To fit a model to data of the form (u,i, j), we need to use a ranking loss, so that the model ranks i 
ahead of j for user u. A simple way to do this is to use a Bernoulli model of the form 


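$$p(y_n = (u, i, j) \mid \theta) = \sigma(f(u, i; \theta) - f(u, j; \theta)) \tag{22.13}$$

If we combine this with a Gaussian prior for $\theta$, we get the following MAP estimation problem:

$$\mathcal{L}(\theta) = \sum_{(u,i,j) \in \mathcal{D}} \log \sigma(f(u, i; \theta) - f(u, j; \theta)) - \lambda \|\theta\|^2 \tag{22.14}$$


where $\mathcal{D} = \{(u, i, j) : i \in \mathcal{I}_u^+, j \in \mathcal{I} \setminus \mathcal{I}_u^+\}$, where $\mathcal{I}_u^+$ is the set of all items that user $u$ selected, and
$\mathcal{I} \setminus \mathcal{I}_u^+$ are all the other items (which they may dislike, or simply may not have seen). This is known
as Bayesian personalized ranking or BPR [Ren+09].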

Let us consider this example from [Zha+20, Sec 16.5]. There are 4 items in total, $\mathcal{I} = \{i_1, i_2, i_3, i_4\}$,
and user $u$ chose to interact with $\mathcal{I}_u^+ = \{i_2, i_3\}$. In this case, the implicit item-item preference matrix
for user $u$ has the form


$$Y_u = \begin{pmatrix} \cdot & + & + & ? \\ - & \cdot & ? & - \\ - & ? & \cdot & - \\ ? & + & + & \cdot \end{pmatrix} \tag{22.15}$$


where $Y_{u,i,i'} = +$ means user $u$ prefers $i'$ to $i$, $Y_{u,i,i'} = -$ means user $u$ prefers $i$ to $i'$, and $Y_{u,i,i'} = ?$
means we cannot tell what the user’s preference is. For example, focusing on the second column, we
see that this user rates $i_2$ higher than $i_1$ and $i_4$, since they selected $i_2$ but not $i_1$ or $i_4$; however, we
cannot tell if they prefer $i_2$ over $i_3$ or vice versa.

When the set of possible items is large, the number of negatives in $\mathcal{I} \setminus \mathcal{I}_u^+$ can be very large.
Fortunately we can approximate the loss by subsampling negatives.

Note that an alternative to the log-loss above is to use a hinge loss, similar to the approach used 
in SVMs (Section 17.3). This has the form 


$$\mathcal{L}(y_n = (u, i, j); f) = \max(m - (f(u, i) - f(u, j)), 0) = \max(m - f(u, i) + f(u, j), 0) \tag{22.16}$$

where $m > 0$ is the safety margin. This tries to ensure that each positive item $i$ scores at least $m$
higher than the negative items $j$.
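
The two ranking losses, together with negative subsampling, can be sketched as follows. This is an illustrative NumPy fragment: the inner-product scoring function, the uniform negative sampler, and all names are assumptions made for the example, and items are indexed from 0.

```python
import numpy as np

def bpr_nll(s_pos, s_neg):
    """Negative log-likelihood -log sigma(s_pos - s_neg), computed stably."""
    return np.logaddexp(0.0, -(s_pos - s_neg))

def hinge_rank_loss(s_pos, s_neg, m=1.0):
    """max(m - (s_pos - s_neg), 0): zero once the positive leads by at least m."""
    return np.maximum(m - (s_pos - s_neg), 0.0)

def sample_triples(u, positives, n_items, n_neg, rng):
    """Pair each positive item of user u with n_neg uniformly sampled negatives."""
    negatives = np.setdiff1d(np.arange(n_items), positives)
    return [(u, i, j) for i in positives
            for j in rng.choice(negatives, size=n_neg, replace=True)]

rng = np.random.default_rng(0)
U = rng.standard_normal((3, 2))    # user embeddings (hypothetical)
V = rng.standard_normal((4, 2))    # item embeddings (hypothetical)
score = lambda u, i: U[u] @ V[i]   # f(u, i) taken to be an inner product

# User 0 selected items 1 and 2 out of 4, so items 0 and 3 are the negatives.
triples = sample_triples(0, np.array([1, 2]), n_items=4, n_neg=2, rng=rng)
loss = np.mean([bpr_nll(score(u, i), score(u, j)) for u, i, j in triples])
```

Minimizing the averaged `bpr_nll` over sampled triples is a stochastic approximation to maximizing the log-likelihood term of the BPR objective; the Gaussian prior would add an $\ell_2$ penalty on the embeddings.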
22.2.2 Factorization machines


The AutoRec approach of Section 22.1.4 is nonlinear, but treats users and items asymmetrically. 
In this section, we discuss a more symmetric discriminative modeling approach. We start with a 
linear version. The basic idea is to predict the output (such as a rating) for any given user-item pair,
$x = [\text{one-hot}(u), \text{one-hot}(i)]$, using


$$f(x) = \mu + \sum_{i=1}^{D} w_i x_i + \sum_{i=1}^{D} \sum_{j=i+1}^{D} (v_i^\top v_j) x_i x_j \tag{22.17}$$


where $x \in \mathbb{R}^D$, $D = (M + N)$ is the number of inputs, $V \in \mathbb{R}^{D \times K}$ is a weight matrix, $w \in \mathbb{R}^D$
is a weight vector, and $\mu \in \mathbb{R}$ is a global offset. This is known as a factorization machine (FM)
[Ren12].

The term $(v_i^\top v_j) x_i x_j$ measures the interaction between features $i$ and $j$ in the input. This generalizes
the matrix factorization model of Equation (22.4), since it can handle other kinds of information in
the input $x$, beyond just user and item, as we discuss in Section 22.3.

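Computing Equation (22.17) takes $O(KD^2)$ time, since it considers all possible pairwise interactions
between the $D$ input features. Fortunately we can rewrite this so that we can compute it in
$O(KD)$ time as follows:


$$\sum_{i=1}^{D} \sum_{j=i+1}^{D} (v_i^\top v_j) x_i x_j = \frac{1}{2} \sum_{i=1}^{D} \sum_{j=1}^{D} (v_i^\top v_j) x_i x_j - \frac{1}{2} \sum_{i=1}^{D} (v_i^\top v_i) x_i x_i \tag{22.18}$$

$$= \frac{1}{2} \left( \sum_{i=1}^{D} \sum_{j=1}^{D} \sum_{k=1}^{K} v_{ik} v_{jk} x_i x_j - \sum_{i=1}^{D} \sum_{k=1}^{K} v_{ik} v_{ik} x_i x_i \right) \tag{22.19}$$

$$= \frac{1}{2} \sum_{k=1}^{K} \left( \Big( \sum_{i=1}^{D} v_{ik} x_i \Big)^2 - \sum_{i=1}^{D} v_{ik}^2 x_i^2 \right) \tag{22.20}$$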


For sparse vectors, the overall complexity is linear in the number of non-zero components. So if we 
use one-hot encodings of the user and item id, the complexity is just O(K), analogous to the original 
matrix factorization objective of Equation (22.4). 
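
The equivalence between the direct pairwise sum and the linear-time form can be checked numerically. The sketch below uses NumPy with randomly drawn toy parameters; the function names `fm_naive` and `fm_fast` are chosen here for illustration.

```python
import numpy as np

def fm_naive(x, mu, w, V):
    """Direct evaluation of the pairwise sum: O(K D^2)."""
    D = len(x)
    pair = sum((V[i] @ V[j]) * x[i] * x[j]
               for i in range(D) for j in range(i + 1, D))
    return mu + w @ x + pair

def fm_fast(x, mu, w, V):
    """Same value via 0.5 * sum_k [(sum_i v_ik x_i)^2 - sum_i v_ik^2 x_i^2]: O(K D)."""
    s = V.T @ x                       # shape (K,)
    s2 = (V ** 2).T @ (x ** 2)        # shape (K,)
    return mu + w @ x + 0.5 * np.sum(s ** 2 - s2)

rng = np.random.default_rng(0)
M, N, K = 3, 4, 2                     # users, items, latent dimension
D = M + N
mu, w, V = 0.5, rng.standard_normal(D), rng.standard_normal((D, K))

x = np.zeros(D)
u, i = 1, 2
x[u] = 1.0                            # one-hot(u)
x[M + i] = 1.0                        # one-hot(i)
```

With the two one-hot blocks as input, only one pairwise term survives, so the prediction reduces to a global offset, a user bias, an item bias, and one inner product of two rows of `V`.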

We can fit this model to minimize any loss we want. For example, if we have explicit feedback, we
may choose MSE loss, and if we have implicit feedback, we may choose a ranking loss.

In [Guo+17], they propose a model called deep factorization machines, which combines the 
above method with an MLP applied to a concatenation of the embedding vectors, instead of the 
inner product. More precisely, it is a model of the form 


$$f(x; \theta) = \sigma(\text{FM}(x) + \text{MLP}(x)) \tag{22.21}$$

This is closely related to the wide and deep model proposed in [Che+16]. The idea is that
the bilinear FM model captures explicit interactions between specific users and items (a form of
memorization), whereas the MLP captures implicit interactions between user features and item
features, which allows the model to generalize.
Figure 22.5: Illustration of the neural matrix factorization model. From Figure 2 of [He+17]. Used with kind
permission of Xiangnan He.


22.2.3 Neural matrix factorization 


In this section, we describe the neural matrix factorization model of [He+17]. This is another
way to combine bilinear models with deep neural networks. The bilinear part is used to define the
generalized matrix factorization (GMF) pathway, which computes the following feature
vector for user $u$ and item $i$:


$$z^1_{ui} = P_{u,:} \odot Q_{i,:} \tag{22.22}$$


where $P \in \mathbb{R}^{M \times K}$ is a user embedding matrix, and $Q \in \mathbb{R}^{N \times K}$ is an item embedding matrix. The DNN
part is just an MLP applied to a concatenation of the embedding vectors (using different embedding
matrices):


$$z^2_{ui} = \text{MLP}([U_{u,:}, V_{i,:}]) \tag{22.23}$$

Finally, the model combines these to get

$$f(u, i; \theta) = \sigma(w^\top [z^1_{ui}, z^2_{ui}]) \tag{22.24}$$


See Figure 22.5 for an illustration.

In [He+17], the model is trained on implicit feedback, where $y_{ui} = 1$ if the interaction of user $u$
with item $i$ is observed, and $y_{ui} = 0$ otherwise. However, it could be trained to minimize BPR loss.
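
A forward pass through the two pathways can be sketched as below. This is a minimal NumPy illustration: the single ReLU layer standing in for the MLP, the layer sizes, and the function name `neumf_score` are assumptions, not details taken from [He+17].

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def neumf_score(u, i, params):
    P, Q, U, V, W1, b1, w = params
    z1 = P[u] * Q[i]                              # GMF pathway: elementwise product
    h = np.concatenate([U[u], V[i]])              # separate embeddings for the MLP
    z2 = np.maximum(W1 @ h + b1, 0.0)             # single ReLU layer (assumed depth)
    return sigmoid(w @ np.concatenate([z1, z2]))  # combine both pathways

rng = np.random.default_rng(0)
M, N, K, H = 5, 7, 3, 4                           # users, items, embed dim, hidden units
params = (rng.standard_normal((M, K)), rng.standard_normal((N, K)),
          rng.standard_normal((M, K)), rng.standard_normal((N, K)),
          rng.standard_normal((H, 2 * K)), np.zeros(H),
          rng.standard_normal(K + H))
p = neumf_score(2, 4, params)                     # predicted interaction probability
bce = -np.log(p)                                  # log loss if y_ui = 1
```

Since the output is a probability, binary cross-entropy on observed versus unobserved pairs is a natural fit; swapping in a pairwise ranking loss only requires scoring a positive and a negative item for the same user.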


22.3 Leveraging side information 


So far, we have assumed that the only information available to the predictor is the integer id of the
user and the integer id of the item. This is an extremely impoverished representation, and will fail to
work if we encounter a new user or new item (the so-called cold start problem). To overcome this,
we need to leverage “side information”, beyond just the id of the user/item.

Figure 22.6: Illustration of a design matrix for a movie recommender system, where we show the id of the
user and movie, as well as other side information. From Figure 1 of [Ren12]. Used with kind permission of
Steffen Rendle.

There are many forms of side information we can use. For items, we often have rich meta-data,
such as text (e.g., title), images (e.g., cover), high-dimensional categorical variables (e.g., location), or
just scalars (e.g., price). For users, the side information available depends on the specific form of the
interactive system. For search engines, it is the list of queries the user has issued, and (if they are
logged in) information derived from websites they have visited (which is tracked via cookies). For
online shopping sites, it is the list of searches plus past viewing and purchasing behavior. For social 
networking sites, there is information about the friendship graph of each user. 

It is very easy to capture this side information in the factorization machines framework, by
expanding our definition of $x$ beyond the two one-hot vectors, as illustrated in Figure 22.6. The
same input encoding can of course be fed into other kinds of models, such as deepFM or neuralMF.

In addition to features about the user and item, there may be other contextual features, such as 
the time of the interaction (e.g., the day or evening). The order (sequence) of the most recently 
viewed items is often also a useful signal. The “Convolutional Sequence Embedding Recommendation” 
or Caser model proposed in [TW18] captures this by embedding the last M items, and then treating 
the M x K input as an image, by using a convolutional layer as part of the model. 

Many other kinds of neural models can be designed for the recommender task. See e.g., [Zha+19b] 
for a review. 


22.4  Exploration-exploitation tradeoff 


An interesting “twist” to recommender systems that does not arise in other kinds of prediction 
problems is the fact that the data that the system is trained on is a consequence of recommendations 
made by earlier versions of the system. Thus there is a feedback loop [Bot+13]. For example, consider 
the YouTube video recommendation system [CAS16]. There are millions of videos on the site, so the 
system must come up with a shortlist, or “slate”, of videos to show the user, to help them find what 
they want (see e.g., [Ie+19]). If the user watches one of these videos, the system can consider this 
positive feedback that it made a good recommendation, and it can update the model parameters 
accordingly. However, maybe there was some other video that the user would have liked even more? 
It is impossible to answer this counterfactual unless the system takes a chance and shows some
items for which the user response is uncertain. This is an example of the exploration-exploitation
tradeoff.

In addition to needing to explore, the system may have to wait for a long time until it can detect
if a change it made to its recommendation policies was beneficial. It is common to use reinforcement
learning to learn policies which optimize long-term reward. See the sequel to this book, [Mur23],
for details.
This chapter is coauthored with Bryan Perozzi, Sami Abu-El-Haija and Ines Chami, and is based on
[Cha+21].


23.1 Introduction 


We now turn our focus to data which has semantic relationships between training samples $\{x_n\}_{n=1}^N$.
The relationships (known as edges) connect training samples (nodes) with an application specific
meaning (commonly similarity). Graphs provide the mathematical foundations for reasoning about
these kinds of relationships.

Graphs are universal data structures that can represent complex relational data (composed of 
nodes and edges), and appear in multiple domains such as social networks, computational chemistry 
[Gil+17], biology [Sta+06], recommendation systems [KSJ09], semi-supervised learning [GB18], and 
others. 

Let $A \in \{0,1\}^{N \times N}$ be the adjacency matrix, where $N$ is the number of nodes, and let $W \in \mathbb{R}^{N \times N}$
be a weighted version. In the methods we discuss below, some set $W = A$ while others set $W$ to a
transformation of $A$, such as row-wise normalization. Finally, let $X \in \mathbb{R}^{N \times D}$ be a matrix of node
features.

When designing and training a neural network model over graph data, we want the
method to be applicable to nodes which participate in different graph settings (e.g. have differing
connections and community structure). Contrast this with a neural network model designed for images,
where each pixel (node) has the same neighborhood structure. An arbitrary graph, on the other hand, has no
specified alignment of nodes, and further, each node might have a different neighborhood structure.
See Figure 23.1 for a comparison. Consequently, operations like Euclidean spatial convolution cannot 
be directly applied on irregular graphs: Euclidean convolutions strongly rely on geometric priors 
(such as shift invariance), which don’t generalize to non-Euclidean domains. 

These challenges led to the development of Geometric Deep Learning (GDL) research [Bro+17b], 
which aims at applying deep learning techniques to non-Euclidean data. In particular, given the 
widespread prevalence of graphs in real-world applications, there has been a surge of interest in 
applying machine learning methods to graph-structured data. Among these, Graph Representation
Learning (GRL) [Cha+21] methods aim at learning low-dimensional continuous vector
representations for graph-structured data, also called embeddings.

We divide GRL here into two classes of problems: unsupervised and supervised (or semi-supervised)
GRL. The first class aims at learning low-dimensional Euclidean representations optimizing
an objective, e.g. one that preserves the structure of an input graph. The second class also learns
low-dimensional Euclidean representations but for a specific downstream prediction task such as
node or graph classification. Further, the graph structure can be fixed throughout training and
testing, which is known as the transductive learning setting (e.g. predicting user properties in a
large social network), or alternatively the model is expected to answer questions about graphs not
seen during training, known as the inductive learning setting (e.g. classifying molecular structures).
Finally, while most supervised and unsupervised methods learn representations in Euclidean vector
spaces, there has recently been interest in non-Euclidean representation learning, which aims
at learning non-Euclidean embedding spaces such as hyperbolic or spherical spaces. The main
motivation for this body of work is to use a continuous embedding space that resembles the
underlying discrete structure of the input data it tries to embed (e.g. the hyperbolic space is a
continuous version of trees [Sar11]).

Figure 23.1: An illustration of Euclidean vs. non-Euclidean graphs. Used with permission from [Cha+21].


23.2 Graph Embedding as an Encoder/Decoder Problem 


While there are many approaches to GRL, many methods follow a similar pattern. First, the network
input (node features $X \in \mathbb{R}^{N \times D}$ and graph edges in $A$ or $W \in \mathbb{R}^{N \times N}$) is encoded from the discrete
domain of the graph into a continuous representation (embedding), $Z \in \mathbb{R}^{N \times L}$. Next, the learned
representation $Z$ is used to optimize a particular objective (such as reconstructing the links of the
graph). In this section we will use the graph encoder-decoder model (GRAPHEDM) proposed by
Chami et al. [Cha+21] to analyze popular families of GRL methods.

The GRAPHEDM framework (Figure 23.2, [Cha+21]) provides a general framework that encapsulates
a wide variety of supervised and unsupervised graph embedding methods, including ones
utilizing the graph as a regularizer (e.g. [ZG02]), positional embeddings (e.g. [PARS14]), and graph
neural networks such as ones based on message passing [Gil+17; Sca+09] or graph convolutions
[Bru+14; KW16a].

Figure 23.2: Illustration of the GRAPHEDM framework from Chami et al. [Cha+21]. Based on the supervision
available, methods will use some or all of the branches. In particular, unsupervised methods do not leverage
label decoding for training and only optimize the similarity decoder (lower branch). On the other hand,
semi-supervised and supervised methods leverage the additional supervision to learn models’ parameters (upper
branch). Reprinted with permission from [Cha+21].

The GRAPHEDM framework takes as input a weighted graph $W \in \mathbb{R}^{N \times N}$, and optional node
features $X \in \mathbb{R}^{N \times D}$. In (semi-)supervised settings, we assume that we are given training target labels
for nodes (denoted $N$), edges (denoted $E$), and/or for the entire graph (denoted $G$). We denote the
supervision signal as $S \in \{N, E, G\}$, as presented below.

The GRAPHEDM model itself can be decomposed into the following components: 


• Graph encoder network $\text{ENC}_{\Theta^E} : \mathbb{R}^{N \times N} \times \mathbb{R}^{N \times D} \to \mathbb{R}^{N \times L}$, parameterized by $\Theta^E$, which
combines the graph structure with optional node features to produce a node embedding matrix
$Z \in \mathbb{R}^{N \times L}$ as follows:


$$Z = \text{ENC}(W, X; \Theta^E). \tag{23.1}$$


As we shall see next, this node embedding matrix might capture different graph properties
depending on the supervision used for training.


• Graph decoder network $\text{DEC}_{\Theta^D} : \mathbb{R}^{N \times L} \to \mathbb{R}^{N \times N}$, parameterized by $\Theta^D$, which uses the
node embeddings $Z$ to compute similarity scores for all node pairs in matrix $\hat{W} \in \mathbb{R}^{N \times N}$ as
follows:


$$\hat{W} = \text{DEC}(Z; \Theta^D). \tag{23.2}$$


• Classification network $\text{DEC}_{\Theta^S} : \mathbb{R}^{N \times L} \to \mathbb{R}^{N \times |\mathcal{Y}|}$, where $\mathcal{Y}$ is the label space. This network is
used in (semi-)supervised settings and parameterized by $\Theta^S$. The output is a distribution over
the labels $\hat{y}^S$, using node embeddings, as follows:


$$\hat{y}^S = \text{DEC}(Z; \Theta^S). \tag{23.3}$$


Specific choices of the aforementioned (encoder and decoder) networks allow GRAPHEDM to realize
specific graph embedding methods, as we explain in the next subsections.
Figure 23.3: Shallow embedding methods. The encoder is a simple embedding look-up and the graph structure
is only used in the loss function. Reprinted with permission from [Cha+21].


The output of a model, as described by the GRAPHEDM framework, is a reconstructed graph similarity
matrix $\hat{W}$ (often used to train unsupervised embedding algorithms), and/or labels $\hat{y}^S$ for supervised
applications. The label output space $\mathcal{Y}$ is application dependent. For instance, in node-level
classification, $\hat{y}^N \in \mathcal{Y}^N$, with $\mathcal{Y}$ representing the node label space. Alternatively, for edge-level labeling,
$\hat{y}^E \in \mathcal{Y}^{N \times N}$, with $\mathcal{Y}$ representing the edge label space. Finally, we note that other kinds of labeling
are possible, such as graph-level labeling (where we would say $\hat{y}^G \in \mathcal{Y}$, with $\mathcal{Y}$ representing the graph
label space).

Finally, a loss must be specified. This can be used to optimize the parameters $\Theta = \{\Theta^E, \Theta^D, \Theta^S\}$.
GRAPHEDM models can be optimized using a combination of three different terms. First, a supervised
loss term, $\mathcal{L}^S_{\text{SUP}}$, compares the predicted labels $\hat{y}^S$ to the ground truth labels $y^S$. Next, a graph
reconstruction loss term, $\mathcal{L}_{G,\text{RECON}}$, may leverage the graph structure to impose regularization
constraints on the model parameters. Finally, a weight regularization loss term, $\mathcal{L}_{\text{REG}}$, allows
representing priors on trainable model parameters for reducing overfitting. Models realizable by the
GRAPHEDM framework are trained by minimizing the total loss $\mathcal{L}$ defined as:


$$\mathcal{L} = \alpha \mathcal{L}^S_{\text{SUP}}(y^S, \hat{y}^S; \Theta) + \beta \mathcal{L}_{G,\text{RECON}}(W, \hat{W}; \Theta) + \gamma \mathcal{L}_{\text{REG}}(\Theta), \tag{23.4}$$


where $\alpha$, $\beta$ and $\gamma$ are hyper-parameters, that can be tuned or set to zero. Note that graph embedding
methods can be trained in a supervised ($\alpha \neq 0$) or unsupervised ($\alpha = 0$) fashion. Supervised graph
embedding approaches leverage an additional source of information to learn embeddings such as node
or graph labels. On the other hand, unsupervised network embedding approaches rely on the graph
structure only to learn node embeddings.
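
The way the three loss terms combine can be sketched for one concrete instantiation. In this illustration the embedding table plays the role of the encoder, the similarity decoder is a dot product, the label decoder is a linear softmax classifier, and the penalties are squared norms; all of these choices, and all names, are assumptions made for the example rather than part of the framework itself.

```python
import numpy as np

def softmax(a):
    a = a - a.max(axis=1, keepdims=True)              # shift for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=1, keepdims=True)

def graphedm_loss(W, Z, C, y, labeled, alpha, beta, gamma):
    W_hat = Z @ Z.T                                   # similarity decoder
    probs = softmax(Z @ C)                            # label decoder (linear + softmax)
    sup = -np.mean(np.log(probs[labeled, y[labeled]]))  # cross-entropy on labeled nodes
    recon = np.sum((W - W_hat) ** 2)                  # graph reconstruction term
    reg = np.sum(Z ** 2) + np.sum(C ** 2)             # weight regularization term
    return alpha * sup + beta * recon + gamma * reg

rng = np.random.default_rng(0)
W = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
Z = 0.1 * rng.standard_normal((4, 2))                 # node embeddings
C = 0.1 * rng.standard_normal((2, 3))                 # classifier weights, 3 classes
y = np.array([0, 1, 2, 0])                            # node labels
labeled = np.array([0, 2])                            # only two nodes carry labels

semi = graphedm_loss(W, Z, C, y, labeled, alpha=1.0, beta=0.5, gamma=0.01)
unsup = graphedm_loss(W, Z, C, y, labeled, alpha=0.0, beta=0.5, gamma=0.01)
```

Setting `alpha=0.0` switches off the label branch entirely, so the same function covers both the unsupervised and the (semi-)supervised regimes.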


23.3 Shallow graph embeddings 


Shallow embedding methods are transductive graph embedding methods, where the encoder function
maps categorical node IDs onto a Euclidean space through an embedding matrix. Each node $v_i \in V$
has a corresponding low-dimensional learnable embedding vector $Z_i \in \mathbb{R}^L$ and the shallow encoder
function is


$$Z = \text{ENC}(\Theta^E) \triangleq \Theta^E \quad \text{where } \Theta^E \in \mathbb{R}^{N \times L}. \tag{23.5}$$


Crucially, the embedding dictionary Z is directly learned as model parameters. In the unsupervised 
case, embeddings Z are optimized to recover some information about the input graph (e.g., the 
adjacency matrix W, or some transformation of it). This is somewhat similar to dimensionality 
reduction methods, such as PCA (Section 20.1), but for graph data structures. In the supervised 
case, the embeddings are optimized to predict some labels, for nodes, edges and/or the whole graph. 
23.3.1 Unsupervised embeddings


In the unsupervised case, we will consider two main types of shallow graph embedding methods:
distance-based and outer product-based. Distance-based methods optimize the embedding dictionary
$Z = \Theta^E \in \mathbb{R}^{N \times L}$ such that nodes $i$ and $j$ which are close in the graph (as measured by some
graph distance function) are embedded in $Z$ such that $d_2(Z_i, Z_j)$ is small, where $d_2(\cdot,\cdot)$ is a pairwise
distance function between embedding vectors. The distance function $d_2(\cdot,\cdot)$ can be customized, which
can lead to Euclidean (Section 23.3.2) or non-Euclidean (Section 23.3.3) embeddings. The decoder
outputs a node-to-node matrix $\hat{W} = \text{DEC}(Z; \Theta^D)$, with $\hat{W}_{ij} = d_2(Z_i, Z_j)$.

Alternatively, some methods rely on pairwise dot-products to compute node similarities. The
decoder network can be written as: $\hat{W} = \text{DEC}(Z; \Theta^D) = ZZ^\top$.

In both cases, unsupervised embeddings for distance- and product-based methods are learned by
minimizing the graph regularization loss:


$$\mathcal{L}_{G,\text{RECON}}(W, \hat{W}; \Theta) = d_1(s(W), \hat{W}), \tag{23.6}$$


where $s(W)$ is an optional transformation of the adjacency matrix $W$, and $d_1$ is a pairwise distance
function between matrices, which does not need to be of the same form as $d_2$. As we shall see, there
are many plausible choices for $s$, $d_1$, $d_2$. For instance, we can let $s$ be the adjacency matrix itself,
$s(W) = W$, or a power of it, e.g. $s(W) = W^2$. If the input is a binary adjacency matrix $W = A$, we
can set $s(W) = 1 - W$, so that connected nodes with $A_{ij} = 1$ get a weight (distance) of 0.
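
The two decoder families and the reconstruction loss can be sketched as follows. This is an illustrative NumPy fragment in which $d_1$ is taken to be the squared Frobenius distance and $d_2$ the Euclidean distance; those choices, the function names, and the toy graph are assumptions.

```python
import numpy as np

def decode_distance(Z):
    """W_hat[i, j] = ||Z_i - Z_j||_2 (Euclidean choice for d_2)."""
    diff = Z[:, None, :] - Z[None, :, :]
    return np.linalg.norm(diff, axis=-1)

def decode_dot(Z):
    """W_hat = Z Z^T (outer-product decoder)."""
    return Z @ Z.T

def recon_loss(W, W_hat, s=lambda W: W):
    """d_1 taken to be the squared Frobenius distance between s(W) and W_hat."""
    return np.sum((s(W) - W_hat) ** 2)

A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)       # path graph on 3 nodes
rng = np.random.default_rng(0)
Z = rng.standard_normal((3, 2))              # shallow encoder: Z is the parameter table

# Distance-based: compare distances with s(W) = 1 - W (connected pairs get target 0).
loss_dist = recon_loss(A, decode_distance(Z), s=lambda W: 1.0 - W)
# Outer product-based: compare similarities with s(W) = W.
loss_dot = recon_loss(A, decode_dot(Z))
```

In both cases the graph enters only through the target `s(W)` in the loss; the encoder itself never sees the edges, which is what makes these methods "shallow".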


23.3.2 Distance-based: Euclidean methods 


Distance-based methods minimize Euclidean distances between similar (connected) nodes. We give 
some examples below. 

Multi-dimensional scaling (MDS, Section 20.4.4) is equivalent to setting $s(W)$ to some distance
matrix measuring the dissimilarity between nodes (e.g. proportional to pairwise shortest distance)
and then defining


$$d_1(s(W), \hat{W}) = \sum_{i,j} (s(W)_{ij} - \hat{W}_{ij})^2 = \|s(W) - \hat{W}\|_F^2 \tag{23.7}$$


where $\hat{W}_{ij} = d_2(Z_i, Z_j) = \|Z_i - Z_j\|$ (although other distance metrics are plausible).
Laplacian eigenmaps (Section 20.4.9) learn embeddings by solving the generalized eigenvector
problem


$$\min_{Z \in \mathbb{R}^{|V| \times d}} \text{tr}(Z^\top L Z) \quad \text{s.t.} \quad Z^\top D Z = I \ \text{and} \ Z^\top D \mathbf{1} = 0 \tag{23.8}$$


where $L = D - W$ is the graph Laplacian (Section 20.4.9.2), and $D$ is a diagonal matrix containing
the sum across columns for each row. The first constraint removes an arbitrary scaling factor in the
embedding and the second one removes trivial solutions corresponding to the constant eigenvector
(with eigenvalue zero for connected graphs). Further, note that $\text{tr}(Z^\top L Z) = \frac{1}{2} \sum_{i,j} W_{ij} \|Z_i - Z_j\|_2^2$,
where $Z_i$ is the $i$'th row of $Z$; therefore the minimization objective can be equivalently written as a
graph reconstruction term, as follows:


$$d_1(s(W), \hat{W}) = \sum_{i,j} W_{ij} \times \hat{W}_{ij} \tag{23.9}$$

$$\hat{W}_{ij} = d_2(Z_i, Z_j) = \|Z_i - Z_j\|_2^2 \tag{23.10}$$


where $s(W) = W$.
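
The constraints and the trace identity can be checked numerically on a small graph. The sketch below solves the generalized eigenproblem by symmetric normalization, which is one standard route and an assumption of this example; it requires a connected, symmetric graph with no isolated nodes, and the function name is hypothetical.

```python
import numpy as np

def laplacian_eigenmaps(W, dim):
    d = W.sum(axis=1)                          # degrees (row sums), assumed positive
    L = np.diag(d) - W                         # graph Laplacian
    d_isqrt = 1.0 / np.sqrt(d)
    L_sym = d_isqrt[:, None] * L * d_isqrt[None, :]   # D^{-1/2} L D^{-1/2}
    evals, evecs = np.linalg.eigh(L_sym)       # eigenvalues in ascending order
    Y = evecs[:, 1:dim + 1]                    # skip the trivial eigenvector
    return d_isqrt[:, None] * Y                # Z = D^{-1/2} Y

W = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)      # small connected graph
Z = laplacian_eigenmaps(W, dim=2)
D = np.diag(W.sum(axis=1))
L = D - W
trace_term = np.trace(Z.T @ L @ Z)
pair_term = 0.5 * sum(W[i, j] * np.sum((Z[i] - Z[j]) ** 2)
                      for i in range(4) for j in range(4))
```

Here `trace_term` and `pair_term` agree, and `Z` satisfies both constraints, so the embedding can be read as minimizing a weighted sum of squared distances between connected nodes.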


23.3.3 Distance-based: non-Euclidean methods 


So far, we have discussed methods which assume that embeddings lie in a Euclidean space. However,
recent work has considered hyperbolic geometry for graph embedding. In particular, hyperbolic 
embeddings are ideal for embedding trees and offer an exciting alternative to Euclidean geometry for 
graphs that exhibit hierarchical structures. We give some examples below. 

Nickel and Kiela [NK17] learn embeddings of hierarchical graphs using the Poincaré model of
hyperbolic space. This is simple to represent in our notation as we only need to change $d_2(Z_i, Z_j)$ to
the Poincaré distance function:


$$d_2(Z_i, Z_j) = d_{\text{Poincaré}}(Z_i, Z_j) = \text{arcosh}\left(1 + 2 \frac{\|Z_i - Z_j\|_2^2}{(1 - \|Z_i\|_2^2)(1 - \|Z_j\|_2^2)}\right). \tag{23.11}$$


The optimization then learns embeddings which minimize distances between connected nodes while
maximizing distances between disconnected nodes:


$$d_1(W, \hat{W}) = -\sum_{i,j} W_{ij} \log \frac{e^{-\hat{W}_{ij}}}{\sum_{k \mid W_{ik} = 0} e^{-\hat{W}_{ik}}} \tag{23.12}$$

where the denominator is approximated using negative sampling. Note that since the hyperbolic
space has a manifold structure, care needs to be taken to ensure that the embeddings remain on the
manifold (using Riemannian optimization techniques [Bon18]).
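
The Poincaré distance is easy to evaluate directly, which helps build intuition for how distances blow up near the boundary of the unit ball. The following is a minimal NumPy sketch; the small `eps` guard against division by zero is an implementation convenience assumed here.

```python
import numpy as np

def poincare_distance(x, y, eps=1e-9):
    """Distance between two points strictly inside the unit ball."""
    sq = np.sum((x - y) ** 2)
    denom = (1.0 - np.sum(x ** 2)) * (1.0 - np.sum(y ** 2))
    return np.arccosh(1.0 + 2.0 * sq / max(denom, eps))

origin = np.zeros(2)
a = np.array([0.5, 0.0])
b = np.array([0.9, 0.0])

d_oa = poincare_distance(origin, a)   # Euclidean gap 0.5
d_ab = poincare_distance(a, b)        # Euclidean gap 0.4, but closer to the boundary
d_ob = poincare_distance(origin, b)
```

Although `a` and `b` are closer together in Euclidean terms than `origin` and `a`, their hyperbolic distance is larger. This growth of distance toward the boundary is what lets the ball accommodate the exponentially many leaves of a tree.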

Other variants of these methods have been proposed. Nickel and Kiela [NK18] explore the Lorentz
model of hyperbolic space, and show that it provides better numerical stability than the Poincaré
model. Another line of work extends non-Euclidean embeddings to mixed-curvature product spaces 
[Gu+18], which provide more flexibility for other types of graphs (e.g. ring of trees). Finally, work 
by Chamberlain, Clough, and Deisenroth [CCD17] extends Poincaré embeddings using skip-gram 
losses with hyperbolic inner products. 


23.3.4 Outer product-based: Matrix factorization methods 


Matrix factorization approaches learn embeddings that lead to a low rank representation of some
similarity matrix $s(W)$, with $s : \mathbb{R}^{N \times N} \to \mathbb{R}^{N \times N}$. The following are frequent choices: $s(W) = W$,
$s(W) = L$ (Graph Laplacian), or other proximity measures such as the Katz centrality index, Common
Neighbors or Adamic/Adar index.

The decoder function in matrix factorization methods is just a dot product:


$$\hat{W} = \text{DEC}(Z; \Theta^D) = ZZ^\top \tag{23.13}$$
Figure 23.4: An overview of the pipeline for random-walk graph embedding methods. Reprinted with permission
from [God18].


Matrix factorization methods learn $Z$ by minimizing a regularization loss $\mathcal{L}_{G,\text{RECON}}(W, \hat{W}; \Theta) =
\|s(W) - \hat{W}\|_F^2$.

The graph factorization (GF) method of [Ahm+13] learns a low-rank factorization of a graph by
minimizing the graph regularization loss $\mathcal{L}_{G,\text{RECON}}(W, \hat{W}; \Theta) = \sum_{(v_i, v_j) \in E} (W_{ij} - \hat{W}_{ij})^2$.

Note that if $A$ is the binary adjacency matrix ($A_{ij} = 1$ iff $(v_i, v_j) \in E$ and $A_{ij} = 0$ otherwise),
the graph regularization loss can be expressed in terms of the Frobenius norm:


$$\mathcal{L}_{G,\text{RECON}}(W, \hat{W}; \Theta) = \|A \odot (W - \hat{W})\|_F^2, \tag{23.14}$$


where $\odot$ is the element-wise matrix multiplication operator. Therefore, GF also learns a low-rank
factorization of the adjacency matrix $W$ measured in Frobenius norm. We note that this is a sparse
operation (summing only over edges which exist in the graph), and so the method has computational
complexity $O(M)$, where $M$ is the number of edges.
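
The agreement between the edge-wise sum and the masked Frobenius form can be verified on a toy graph. In this NumPy sketch the edge set is taken to contain both orientations of every undirected edge, which is an assumption about how $E$ is enumerated; the weights and embeddings are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 1],
              [1, 0, 0, 1],
              [0, 1, 1, 0]], dtype=float)     # binary adjacency matrix (a 4-cycle)
W = A * rng.uniform(0.5, 2.0, size=A.shape)   # toy edge weights, zero off the edges
Z = rng.standard_normal((4, 2))
W_hat = Z @ Z.T                               # dot-product decoder

edges = np.argwhere(A == 1)                   # ordered pairs (i, j) with an edge
loss_edges = sum((W[i, j] - W_hat[i, j]) ** 2 for i, j in edges)   # sparse sum
loss_masked = np.sum((A * (W - W_hat)) ** 2)                        # dense masked form
```

The loop touches only the existing edges, which is where the linear-in-edges cost comes from; the dense masked expression is equivalent but costs $O(N^2)$ to evaluate.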

The methods described so far are all symmetric, that is, they assume that $W_{ij} = W_{ji}$. This is
a limiting assumption when working with directed graphs as some relationships are not reciprocal.
The GraRep method of [CLX15] overcomes this limitation by learning two embeddings per node, a
source embedding $Z^s$ and a target embedding $Z^t$, which capture asymmetric proximity in directed
networks. In addition to asymmetry, GraRep learns embeddings that preserve $k$-hop neighborhoods
via powers of the adjacency matrix and minimizes a graph reconstruction loss with:


$$\hat{W}^{(k)} = Z^{(k),s} {Z^{(k),t}}^\top \tag{23.15}$$

$$\mathcal{L}_{G,\text{RECON}}(W, \hat{W}^{(k)}; \Theta) = \|D^{-k} W^k - \hat{W}^{(k)}\|_F^2, \tag{23.16}$$

for each $1 \le k \le K$. GraRep concatenates all representations to get source embeddings $Z^s =
[Z^{(1),s} | \ldots | Z^{(K),s}]$ and target embeddings $Z^t = [Z^{(1),t} | \ldots | Z^{(K),t}]$. Unfortunately, GraRep is not very
scalable, since it uses powers of the matrix $D^{-1}W$, which become increasingly dense. This limitation
can be circumvented by using implicit matrix factorization [Per+17] as discussed below.
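
The densification of the $k$-step matrices, and the per-$k$ source/target factorization, can be illustrated on a small path graph. This is a simplified sketch: it factorizes the powers of the random-walk matrix $D^{-1}W$ directly with a truncated SVD, which is an assumption of this example and omits any further transformation the published method applies. All names are hypothetical.

```python
import numpy as np

def transition_powers(W, K):
    """Return [(D^{-1} W)^k for k = 1..K], the k-step random-walk matrices."""
    T = W / W.sum(axis=1, keepdims=True)
    powers, Tk = [], np.eye(len(W))
    for _ in range(K):
        Tk = Tk @ T
        powers.append(Tk)
    return powers

def factorize(S, L):
    """Best rank-L factorization S ~ Zs @ Zt.T from a truncated SVD."""
    U, s, Vt = np.linalg.svd(S)
    root = np.sqrt(s[:L])
    return U[:, :L] * root, Vt[:L].T * root

# Path graph on 6 nodes (toy example).
N = 6
W = np.zeros((N, N))
for i in range(N - 1):
    W[i, i + 1] = W[i + 1, i] = 1.0

powers = transition_powers(W, K=3)
nnz = [int(np.count_nonzero(P)) for P in powers]       # number of nonzeros per power
parts = [factorize(P, L=2) for P in powers]
Z_s = np.concatenate([zs for zs, _ in parts], axis=1)  # source embeddings, shape (N, K*L)
Z_t = np.concatenate([zt for _, zt in parts], axis=1)  # target embeddings, shape (N, K*L)
```

On this graph the number of nonzero entries grows with each power, which is the scalability problem in miniature: on a large sparse graph, the higher powers quickly stop being sparse.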


23.3.5 Outer product-based: Skip-gram methods 


Skip-gram graph embedding models were inspired by research in natural language processing to model 
the distributional behavior of words [Mik+13c; PSM14b]. Skip-gram word embeddings are optimized 
to predict words in their context (the surrounding words) for each target word in a sentence. Given
a sequence of words $(w_1, \ldots, w_T)$, skip-gram will minimize the objective

$$\mathcal{L} = - \sum_{-K \le i \le K,\, i \ne 0} \log P(w_{k-i} \mid w_k),$$


for each target word $w_k$. These conditional probabilities can be efficiently estimated using neural
networks. See Section 20.5.2.2 for details. 

This idea has been leveraged for graph embeddings in the DeepWalk framework of [PARS14]. 
They justified this by showing empirically how the frequency statistics induced by random walks 
in real graphs follow a distribution similar to that of words used in natural language. In terms of 
GRAPHEDM, skip-gram graph embedding methods use an outer product (Equation 23.13) as their 
decoder function and a graph reconstruction term computed over random walks on the graph. 

In more detail, DeepWalk trains node embeddings to maximize the probability of predicting context
nodes for each center node. The context nodes are nodes appearing adjacent to the center node in
simulated random walks on A. To train embeddings, DeepWalk generates sequences of nodes using
truncated unbiased random walks on the graph (which can be compared to sentences in natural
language models) and then maximizes their log-likelihood. Each random walk starts with a node
$v_{i_1} \in V$ and repeatedly samples the next node uniformly at random: $v_{i_{j+1}} \in \{v \in V \mid (v_{i_j}, v) \in E\}$.
The walk length is a hyperparameter. All generated random-walks can then be encoded by a sequence 
model. This two-step paradigm introduced by [PARS14] has been followed by many subsequent 
works, such as node2vec [GL16]. 
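
The first step of this two-step paradigm can be sketched in a few lines. This is a minimal illustration, assuming an adjacency-list graph; the helper names `random_walk` and `context_pairs` and the window size are hypothetical, and the second step (fitting a skip-gram model on the pairs) is omitted.

```python
import random

def random_walk(adj, start, length, rng):
    """Truncated unbiased random walk over an adjacency-list graph."""
    walk = [start]
    while len(walk) < length:
        nbrs = adj[walk[-1]]
        if not nbrs:  # dead end: stop early
            break
        walk.append(rng.choice(nbrs))  # next node sampled uniformly at random
    return walk

def context_pairs(walk, window):
    """(center, context) pairs within `window` steps, as a skip-gram model consumes them."""
    pairs = []
    for k, center in enumerate(walk):
        lo, hi = max(0, k - window), min(len(walk), k + window + 1)
        pairs.extend((center, walk[c]) for c in range(lo, hi) if c != k)
    return pairs

adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
rng = random.Random(0)
# Two walks of length 5 starting from every node.
walks = [random_walk(adj, v, length=5, rng=rng) for v in adj for _ in range(2)]
pairs = context_pairs(walks[0], window=2)
```

The walk length, the number of walks per node, and the window size are all hyperparameters of the pipeline.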

We note that it is common for underlying implementations to use two distinct representations for
each node, one for when a node is the center of a truncated random walk, and one for when it is in the
context. The implications of this modeling choice are studied further in [AEHPAR17].

To present DeepWalk in the GRAPHEDM framework, we can set: 


$$s(W) = \mathbb{E}_q\left[(D^{-1}W)^q\right] \quad \text{with} \quad q \sim P(Q) = \mathrm{Categorical}([1, 2, \ldots, T_{\max}]) \tag{23.17}$$


where $P(Q = q) = \frac{T_{\max} - 1 + q}{T_{\max}}$ (see [AEH+18] for the derivation).


Training DeepWalk is equivalent to minimizing: 


$$\mathcal{L}_{G,\mathrm{RECON}}(W, \hat{W}; \Theta) = \log Z(Z) - \sum_{v_i \in V,\, v_j \in V} s(W)_{ij} \hat{W}_{ij}, \tag{23.18}$$


where $\hat{W} = ZZ^{\top}$, and the partition function, given by $Z(Z) = \prod_i \sum_j \exp(\hat{W}_{ij})$, can be approximated
in $O(N)$ time via hierarchical softmax (see Section 20.5.2). (It is also common to model $\hat{W} = Z^{\mathrm{out}} {Z^{\mathrm{in}}}^{\top}$
for directed graphs using embedding dictionaries $Z^{\mathrm{out}}, Z^{\mathrm{in}} \in \mathbb{R}^{N \times L}$.)
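
As a rough illustration of this objective, the sketch below evaluates it exactly on a toy graph, which costs $O(N^2)$ and is precisely what hierarchical softmax avoids. The walk-length weights are passed in as an argument and the values used here are illustrative only; the function names are hypothetical.

```python
import numpy as np

def expected_walk_matrix(W, probs):
    """s(W) = sum_q probs[q-1] * (D^{-1} W)^q; assumes every node has at least one edge."""
    P = W / W.sum(axis=1, keepdims=True)
    S, Pq = np.zeros_like(P), np.eye(W.shape[0])
    for p in probs:
        Pq = Pq @ P
        S += p * Pq
    return S

def deepwalk_loss(S, Z):
    """log Z(Z) - sum_ij S_ij * W_hat_ij, with W_hat = Z Z^T and a row-wise partition."""
    W_hat = Z @ Z.T
    row_max = W_hat.max(axis=1, keepdims=True)  # stabilizes the log-sum-exp
    log_partition = np.sum(row_max[:, 0] + np.log(np.exp(W_hat - row_max).sum(axis=1)))
    return float(log_partition - np.sum(S * W_hat))

W = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
S = expected_walk_matrix(W, probs=[0.5, 0.3, 0.2])  # illustrative weights (assumption)
loss_at_zero = deepwalk_loss(S, np.zeros((4, 2)))   # all-zero embeddings as a sanity check
```

With all-zero embeddings every row of the softmax is uniform, so the loss reduces to $N \log N$, which gives a convenient sanity check.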

As noted by [LG14], skip-gram methods can be viewed as implicit matrix factorization, and
the methods discussed here are related to those of matrix factorization (see Section 23.3.4). This
relationship is discussed in depth by [Qiu+18], who propose a general matrix factorization framework,
NetMF, which uses the same underlying graph proximity information as DeepWalk, LINE [Tan+15],
and node2vec [GL16]. Casting the node embedding problem as matrix factorization makes it possible to
inherit the benefits of efficient sparse matrix operations [Qiu+19a].


23.3.6 Supervised embeddings


In many applications, we have labeled data in addition to node features and graph structure. While 
it is possible to tackle a supervised task by first learning unsupervised representations and then using 
them as features in a secondary model, this is not the ideal workflow. Unsupervised node embeddings 
might not preserve important properties of graphs (e.g., node neighborhoods or attributes) that are
most useful for a downstream supervised task. 

In light of this limitation, a number of methods combining these two steps, namely learning 
embeddings and predicting node or graph labels, have been proposed. Here, we focus on simple 
shallow methods. We discuss deep, nonlinear embeddings later on. 


23.3.6.1 Label propagation 


Label propagation (LP) [ZG02] is a very popular algorithm for graph-based semi-supervised node 
classification. The encoder is a shallow model represented by a lookup table Z. LP uses the label 
space to represent the node embeddings directly (i.e. the decoder in LP is simply the identity 
function): 


$$\hat{y}^N = \mathrm{DEC}(Z; \Theta^C) = Z.$$


In particular, LP uses the graph structure to smooth the label distribution over the graph by adding a
regularization term to the loss function, using the underlying assumption that neighbor nodes should
have similar labels (i.e., there exists some label consistency between connected nodes). Laplacian
eigenmaps are utilized in the regularization to enforce this smoothness:

$$\mathcal{L}_{G,\mathrm{RECON}}(W, \hat{W}; \Theta) = \sum_{i,j} W_{ij} \| \hat{y}^N_i - \hat{y}^N_j \|_2^2. \tag{23.19}$$


LP minimizes this energy function over the space of functions that take fixed values on labeled
nodes (i.e., $\hat{y}^N_i = y^N_i \;\; \forall i \mid v_i \in V_L$) using an iterative algorithm that updates an unlabeled node’s label
distribution via the weighted average of its neighbors’ labels.
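
The iterative update can be sketched as follows. This is a minimal version, assuming a dense weight matrix with no isolated nodes and a fixed number of iterations instead of a convergence test; `label_propagation` and `smoothness_energy` are hypothetical names.

```python
import numpy as np

def label_propagation(W, Y, labeled, num_iters=200):
    """Iteratively replace each unlabeled row by the weighted mean of its neighbors' rows."""
    P = W / W.sum(axis=1, keepdims=True)  # row-normalized edge weights
    Y_hat = Y.astype(float).copy()
    for _ in range(num_iters):
        Y_hat = P @ Y_hat
        Y_hat[labeled] = Y[labeled]       # clamp labeled nodes to their ground truth
    return Y_hat

def smoothness_energy(W, Y_hat):
    """sum_ij W_ij * ||y_i - y_j||^2, the regularizer in Equation (23.19)."""
    diff = Y_hat[:, None, :] - Y_hat[None, :, :]
    return float(np.sum(W * np.sum(diff ** 2, axis=-1)))

# Path graph 0-1-2-3; node 0 has class 0 and node 3 has class 1.
W = np.zeros((4, 4))
for i in range(3):
    W[i, i + 1] = W[i + 1, i] = 1.0
Y = np.array([[1, 0], [0, 0], [0, 0], [0, 1]])
labeled = np.array([True, False, False, True])

Y_hat = label_propagation(W, Y, labeled)
predictions = Y_hat.argmax(axis=1)
energy = smoothness_energy(W, Y_hat)
```

On this path graph the two unlabeled nodes converge to the label distributions (2/3, 1/3) and (1/3, 2/3), so each one takes the class of its nearest labeled node.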

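Label spreading (LS) [Zho+04] is a variant of label propagation which minimizes the following
energy function:

$$\mathcal{L}_{G,\mathrm{RECON}}(W, \hat{W}; \Theta) = \sum_{i,j} W_{ij} \left\| \frac{\hat{y}^N_i}{\sqrt{D_i}} - \frac{\hat{y}^N_j}{\sqrt{D_j}} \right\|_2^2, \tag{23.20}$$

where $D_i = \sum_j W_{ij}$ is the degree of node $v_i$.

In both methods, the supervised loss is simply the sum of distances between predicted labels and
ground truth labels (one-hot vectors):

$$\mathcal{L}^N_{\mathrm{SUP}}(y^N, \hat{y}^N; \Theta) = \sum_{i \mid v_i \in V_L} \| y^N_i - \hat{y}^N_i \|_2^2. \tag{23.21}$$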


Note that while the regularization term is computed over all nodes in the graph, the supervised loss is
computed over labeled nodes only. These methods are expected to work well with consistent graphs,
that is, graphs where node proximity in the graph is positively correlated with label similarity.


23.4 Graph Neural Networks


An extensive area of research focuses on defining convolutions over graph data. In the notation of 
Chami et al. [Cha+21], these (semi-)supervised neighborhood aggregation methods can be represented 
by an encoder of the form $Z = \mathrm{ENC}(X, W; \Theta^E)$, and decoders of the form $\hat{W} = \mathrm{DEC}(Z; \Theta^D)$ and/or
$\hat{y}^S = \mathrm{DEC}(Z; \Theta^S)$. There are many models in this family; we review some of them below.


23.4.1 Message passing GNNs 


The original graph neural network (GNN) model of [GMS05; Sca+09] was the first formulation of 
deep learning methods for graph-structured data. It views the supervised graph embedding problem 
as an information diffusion mechanism, where nodes send information to their neighbors until some 
stable equilibrium state is reached. More concretely, given randomly initialized node embeddings $Z^0$,
it applies the following recursion:

$$Z^{t+1} = \mathrm{ENC}(X, W, Z^t; \Theta^E), \tag{23.22}$$

where parameters $\Theta^E$ are reused at every iteration. After convergence ($t = T$), the node embeddings
$Z^T$ are used to predict the final output such as node or graph labels:

$$\hat{y}^S = \mathrm{DEC}(X, Z^T; \Theta^S). \tag{23.23}$$

This process is repeated several times and the GNN parameters $\Theta^E$ and $\Theta^S$ are learned with
backpropagation via the Almeida-Pineda algorithm [Alm87; Pin88]. By Banach’s fixed point theorem,
this process is guaranteed to converge to a unique solution when the recursion provides a contraction
mapping. In light of this, Scarselli et al. [Sca+09] explore maps that can be expressed using message
passing networks:

$$Z_i^{t+1} = \sum_{j \mid (v_i, v_j) \in E} f(X_i, X_j, Z_j^t; \Theta^E), \tag{23.24}$$

where $f(\cdot)$ is a multi-layer perceptron (MLP) constrained to be a contraction mapping. The decoder
function, however, has no constraints and can be any MLP.
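
The role of the contraction property can be illustrated numerically. In the sketch below the message function is a hypothetical single tanh layer averaged over neighbors, with weights scaled so that the update is a contraction in the max norm; it is not the parameterization used by [Sca+09].

```python
import numpy as np

def gnn_fixed_point(X, A_norm, theta, Z0, tol=1e-10, max_iters=1000):
    """Iterate Z <- A_norm @ tanh(Z @ theta + X) until the update is smaller than tol."""
    Z = Z0
    for t in range(1, max_iters + 1):
        Z_next = A_norm @ np.tanh(Z @ theta + X)
        if np.max(np.abs(Z_next - Z)) < tol:
            return Z_next, t
        Z = Z_next
    return Z, max_iters

rng = np.random.default_rng(0)
A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
A_norm = A / A.sum(axis=1, keepdims=True)  # rows sum to 1, so averaging cannot expand distances
X = rng.normal(size=(3, 2))
theta = 0.5 * np.eye(2)                    # halves distances, making the whole map contractive

# Two very different initializations.
Z_a, steps_a = gnn_fixed_point(X, A_norm, theta, rng.normal(size=(3, 2)))
Z_b, steps_b = gnn_fixed_point(X, A_norm, theta, 10.0 * rng.normal(size=(3, 2)))
```

Both runs reach the same embeddings, as Banach’s theorem predicts for a contraction; with larger weights the guarantee is lost and the iteration may fail to settle on a unique solution.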

Li et al. [Li+15] propose Gated Graph Sequence Neural Networks (GGSNNs), which remove
the contraction mapping requirement from GNNs. In GGSNNs, the recursive algorithm in Equation 
23.22 is relaxed by applying mapping functions for a fixed number of steps, where each mapping 
function is a gated recurrent unit [Cho+14b] with parameters shared for every iteration. The GGSNN 
model outputs predictions at every step, and so is particularly useful for tasks which have sequential 
structure (such as temporal graphs). 

Gilmer et al. [Gil+17] provide a framework for graph neural networks called message passing
neural networks (MPNNs), which encapsulates many recent models. In contrast with the GNN
model which runs for an indefinite number of iterations, MPNNs provide an abstraction for modern
approaches, which consist of multi-layer neural networks with a fixed number of layers. At every
layer $\ell$, message functions $f^\ell(\cdot)$ receive messages from neighbors (based on the neighbors’ hidden states),
which are then passed to aggregation functions $h^\ell(\cdot)$:

$$m_i^{\ell+1} = \sum_{j \mid (v_i, v_j) \in E} f^\ell(H_i^\ell, H_j^\ell), \tag{23.25}$$

$$H_i^{\ell+1} = h^\ell(H_i^\ell, m_i^{\ell+1}), \tag{23.26}$$

where $H^0 = X$. After $\ell$ layers of message passing, nodes’ hidden representations encode information
within $\ell$-hop neighborhoods.
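
The claim about hop neighborhoods is easy to check on a toy graph. The sketch below assumes a linear message function that depends only on the sender and a ReLU update; `mpnn_layer` and both weight matrices are hypothetical choices for illustration.

```python
import numpy as np

def mpnn_layer(H, edges, theta_msg, theta_self):
    """One layer: sum linear messages from neighbors, then combine with the node's own state."""
    M = np.zeros((H.shape[0], theta_msg.shape[1]))
    for i, j in edges:  # edge (i, j): node j sends a message to node i
        M[i] += H[j] @ theta_msg
    return np.maximum(H @ theta_self + M, 0.0)  # ReLU update

# Path graph 0-1-2-3 with edges listed in both directions.
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]
X = np.array([[1.0], [0.0], [0.0], [0.0]])  # only node 0 carries a signal
theta = np.eye(1)

H = X
reach = []  # which nodes have received node 0's signal after each layer
for _ in range(3):
    H = mpnn_layer(H, edges, theta, theta)
    reach.append((H[:, 0] > 0).tolist())
```

After one layer the signal has reached node 1, after two layers node 2, and only after three layers node 3, matching the hop count of the path.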

Battaglia et al. [Bat+18] propose GraphNet, which further extends the MPNN framework to
learn representations for edges, nodes and the entire graph using message passing functions. The
explicit addition of edge and graph representations makes the MPNN model more expressive,
and allows graph models to be applied to additional domains.


23.4.2 Spectral Graph Convolutions 


Spectral methods define graph convolutions using the spectral domain of the graph Laplacian matrix. 
These methods broadly fall into two categories: spectrum-based methods, which explicitly compute an 
eigendecomposition of the Laplacian (e.g., spectral CNNs [Bru+14]), and spectrum-free methods,
which are motivated by spectral graph theory but do not actually perform a spectral decomposition
(e.g., graph convolutional networks or GCN [KW16a]).

A major disadvantage of spectrum-based methods is that they rely on the spectrum of the graph
Laplacian and are therefore domain-dependent (i.e., they cannot generalize to new graphs). Moreover,
computing the Laplacian’s spectral decomposition is computationally expensive. Spectrum-free 
methods overcome these limitations by utilizing approximations of these spectral filters. However, 
spectrum-free methods require using the whole graph W, and so do not scale well. 

For more details on spectral approaches, see e.g., [Bro+17b; Cha+21].


23.4.3 Spatial Graph Convolutions 


Spectrum-based methods have an inherent domain dependency which limits the application of a model
trained on one graph to a new dataset. Additionally, spectrum-free methods (e.g., GCNs) require
using the entire graph A, which can quickly become infeasible as the size of the graph grows.

To overcome these limitations, another branch of graph convolutions (spatial methods) borrows
ideas from standard CNNs, applying convolutions in the spatial domain as defined by the graph
topology. For instance, in computer vision, convolutional filters are spatially localized by using fixed 
rectangular patches around each pixel. Combined with the natural ordering of pixels in images (top, 
left, bottom, right), it is possible to reuse filters’ weights at every location. This process significantly 
reduces the total number of parameters needed for a model. While such spatial convolutions cannot 
directly be applied in graph domains, spatial graph convolutions take inspiration from them. The core 
idea is to use neighborhood sampling and attention mechanisms to create fixed-size graph patches, 
overcoming the irregularity of graphs. 


23.4.3.1 Sampling-based spatial methods 


[Figure 23.5 panels: 1. Sample neighborhood. 2. Aggregate feature information from neighbors. 3. Predict graph context and label using aggregated information.]

Figure 23.5: Illustration of the GraphSAGE model. Reprinted with permission from [HYL17].


To overcome the domain dependency and storage limitations of GCNs, Hamilton, Ying, and Leskovec
[HYL17] propose GraphSAGE, a framework to learn inductive node embeddings. Instead of
averaging signals from all one-hop neighbors (via multiplications with the Laplacian matrix), SAGE
samples fixed neighborhoods (of size q) for each node. This removes the strong dependency on fixed 
graph structure and allows generalization to new graphs. At every SAGE layer, nodes aggregate 
information from nodes sampled from their neighborhood (see Figure 23.5). In the GRAPHEDM 
notation, the propagation rule can be written as: 


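$$H^{\ell+1}_{:,i} = \sigma\left(\Theta_1^\ell H^\ell_{:,i} + \Theta_2^\ell \, \mathrm{AGG}(\{H^\ell_{:,j} \mid v_j \in \mathrm{Sample}(\mathrm{nbr}(v_i), q)\})\right), \tag{23.27}$$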


where AGG(-) is an aggregation function. This aggregation function can be any permutation invariant 
operator such as averaging (SAGE-mean) or max-pooling (SAGE-pool). As SAGE works with fixed 
size neighborhoods (and not the entire adjacency matrix), it also reduces the computational complexity 
of training GCNs. 
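
A minimal sketch of one SAGE-mean layer is given below. It stores node states as rows rather than columns, uses a ReLU nonlinearity, and samples with replacement when a node has fewer than $q$ neighbors; these are assumptions made for illustration, and `sage_mean_layer` is a hypothetical name.

```python
import numpy as np

def sage_mean_layer(H, nbrs, theta_self, theta_nbr, q, rng):
    """SAGE-mean: combine each node's state with the mean over q sampled neighbors."""
    out = np.empty((H.shape[0], theta_self.shape[1]))
    for i in range(H.shape[0]):
        # Sample with replacement only when a node has fewer than q neighbors.
        sampled = rng.choice(nbrs[i], size=q, replace=len(nbrs[i]) < q)
        agg = H[sampled].mean(axis=0)
        out[i] = np.maximum(H[i] @ theta_self + agg @ theta_nbr, 0.0)  # ReLU
    return out

rng = np.random.default_rng(0)
nbrs = {0: [1, 2, 3, 4], 1: [0], 2: [0, 3], 3: [0, 2], 4: [0]}  # every node has a neighbor
H0 = np.ones((5, 3))
theta = 0.5 * np.eye(3)
H1 = sage_mean_layer(H0, nbrs, theta, theta, q=2, rng=rng)
```

The cost per node depends on $q$ rather than on the node's degree, which is the source of the computational savings mentioned above.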


23.4.3.2 Attention-based spatial methods 


Attention mechanisms (Section 15.4) have been successfully used in language models where they, for 
example, allow models to identify relevant parts of long sequence inputs. Inspired by their success 
in language, similar ideas have been proposed for graph convolution networks. Such graph-based 
attention models learn to focus their attention on important neighbors during the message passing 
step via parametric patches which are learned on top of node features. This provides more flexibility 
in inductive settings, compared to methods that rely on fixed weights such as GCNs. 

The Graph attention network (GAT) model of [Vel+18] is an attention-based version of GCNs. 
At every GAT layer, it attends over the neighborhood of each node and learns to selectively pick 
nodes which lead to the best performance for some downstream task. The intuition behind this is 
similar to SAGE [HYL17] and makes GAT suitable for inductive and transductive problems. However,
unlike SAGE, which limits the convolution step to fixed-size neighborhoods, GAT allows each node
to attend over the entirety of its neighbors, assigning each of them different weights. The attention
parameters are trained through backpropagation, and the attention scores are then row-normalized 
with a softmax activation. 
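
The row-normalized attention step can be sketched as follows for a single attention head. The additive scoring function with a LeakyReLU of slope 0.2 follows the usual GAT formulation, but the variable names and the toy inputs are assumptions, and the adjacency matrix is assumed to contain self-loops so that every row has at least one entry.

```python
import numpy as np

def gat_attention(H, A, theta, a_src, a_dst, slope=0.2):
    """Single-head attention: returns (new node states, attention matrix alpha)."""
    HW = H @ theta
    # Score for the pair (i, j): LeakyReLU(a_src . HW_i + a_dst . HW_j).
    scores = (HW @ a_src)[:, None] + (HW @ a_dst)[None, :]
    scores = np.where(scores > 0, scores, slope * scores)
    scores = np.where(A > 0, scores, -np.inf)          # attend only over neighbors
    scores = scores - scores.max(axis=1, keepdims=True)
    alpha = np.exp(scores)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)   # row-wise softmax
    return alpha @ HW, alpha

rng = np.random.default_rng(0)
A = np.array([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=float)  # includes self-loops
H = rng.normal(size=(3, 4))
theta = rng.normal(size=(4, 2))
a_src, a_dst = rng.normal(size=2), rng.normal(size=2)
H_new, alpha = gat_attention(H, A, theta, a_src, a_dst)
```

Each row of `alpha` is a probability distribution over that node's neighbors, and non-neighbors receive exactly zero weight.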
23.4.3.3 Geometric spatial methods


Monti et al. [Mon+17] propose MoNet, a general framework that works particularly well when the 
node features lie in a geometric space, such as 3D point clouds or meshes. MoNet learns attention 
patches using parametric functions in a pre-defined spatial domain (e.g. spatial coordinates), and 
then applies convolution filters in the resulting graph domain. 

MoNet generalizes spatial approaches which introduce constructions for convolutions on manifolds,
such as the Geodesic CNN (GCNN) [Mas+15] and the Anisotropic CNN (ACNN) [Bos+16]. Both
GCNN and ACNN use fixed patches that are defined on a specific coordinate system and therefore
cannot generalize to graph-structured data. However, the MoNet framework is more general; any
pseudo-coordinates (i.e., node features) can be used to induce the patches. More formally, if $U^s$ are
pseudo-coordinates and $H^\ell$ are features from another domain, the MoNet layer can be expressed in
our notation as:


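$$H^{\ell+1} = \sigma\left(\sum_{k=1}^{K} (W \odot g_k(U^s)) H^\ell \Theta_k^\ell\right), \tag{23.28}$$

where $\odot$ denotes the elementwise product and $g_k(U^s)$ are the learned parametric patches, which are $N \times N$ matrices. In practice, MoNet
uses Gaussian kernels to learn patches, such that:

$$g_k(U^s) = \exp\left(-\frac{1}{2} (U^s - \mu_k)^{\top} \Sigma_k^{-1} (U^s - \mu_k)\right), \tag{23.29}$$

where $\mu_k$ and $\Sigma_k$ are learned parameters, and $\Sigma_k$ is restricted to be diagonal.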
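
The Gaussian patches and the resulting layer, $H^{\ell+1} = \sigma(\sum_k (W \odot g_k(U^s)) H^\ell \Theta_k^\ell)$, can be sketched as below. The sketch assumes that the pseudo-coordinates of a node pair are the offset between two node positions and that $\sigma$ is a ReLU; the function names and toy parameter values are hypothetical.

```python
import numpy as np

def gaussian_patches(U, mu, sigma2):
    """g_k(U)_ij = exp(-0.5 * (U_ij - mu_k)^T diag(sigma2_k)^{-1} (U_ij - mu_k))."""
    diff = U[None, :, :, :] - mu[:, None, None, :]  # shape (K, N, N, d)
    return np.exp(-0.5 * np.sum(diff ** 2 / sigma2[:, None, None, :], axis=-1))

def monet_layer(W, U, H, mu, sigma2, thetas):
    """ReLU of the sum over k of (W * g_k(U)) @ H @ thetas[k]."""
    G = gaussian_patches(U, mu, sigma2)
    out = sum((W * G[k]) @ H @ thetas[k] for k in range(len(thetas)))
    return np.maximum(out, 0.0)

rng = np.random.default_rng(0)
coords = rng.normal(size=(4, 2))               # e.g. 2D positions of 4 nodes
U = coords[:, None, :] - coords[None, :, :]    # pseudo-coordinates: pairwise offsets (assumption)
W = np.ones((4, 4))
H = rng.normal(size=(4, 3))
mu = np.vstack([np.zeros(2), np.ones(2)])      # K = 2 kernel means
sigma2 = np.ones((2, 2))                       # diagonal variances
thetas = rng.normal(size=(2, 3, 5))
G = gaussian_patches(U, mu, sigma2)
H_next = monet_layer(W, U, H, mu, sigma2, thetas)
```

Each kernel acts as a soft patch: the first kernel, centered at zero offset, gives weight 1 to every node's own position and smaller weights to nodes further away.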


23.4.4 Non-Euclidean Graph Convolutions 


As we discussed in Section 23.3.3, hyperbolic geometry enables learning of shallow embeddings of 
hierarchical graphs which have smaller distortion than Euclidean embeddings. However, one major 
downside of shallow embeddings is that they do not generalize well (if at all) across graphs. On the 
other hand, Graph Neural Networks, which leverage node features, have achieved good results on 
many inductive graph embedding tasks. 

It is natural, then, that there has been recent interest in extending Graph Neural Networks to learn
non-Euclidean embeddings. One major challenge in doing so again revolves around the nature of 
convolution itself. How should we perform convolutions in a non-Euclidean space, where standard 
operations such as inner products and matrix multiplications are not defined? 

Hyperbolic Graph Convolution Networks (HGCN) [Cha+19a] and Hyperbolic Graph Neural 
Networks (HGNN) [LNK19] apply graph convolutions in hyperbolic space by leveraging the Euclidean 
tangent space, which provides a first-order approximation of the hyperbolic manifold at a point. 
For every graph convolution step, node embeddings are mapped to the Euclidean tangent space at 
the origin, where convolutions are applied, and then mapped back to the hyperbolic space. These 
approaches yield significant improvements on graphs that exhibit hierarchical structure (Figure 23.6). 
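
The tangent-space trick can be sketched with the exponential and logarithmic maps at the origin. The sketch assumes the Poincaré ball model with curvature $-1$ and a plain linear aggregation step; it illustrates the idea only and is not the exact HGCN or HGNN layer.

```python
import numpy as np

def exp0(v, eps=1e-12):
    """Map tangent vectors at the origin onto the Poincare ball (curvature -1)."""
    n = np.linalg.norm(v, axis=-1, keepdims=True)
    return np.tanh(n) * v / np.maximum(n, eps)

def log0(y, eps=1e-12):
    """Inverse of exp0: map points of the ball back to the tangent space at the origin."""
    n = np.linalg.norm(y, axis=-1, keepdims=True)
    return np.arctanh(np.minimum(n, 1 - 1e-7)) * y / np.maximum(n, eps)

def hyperbolic_conv(Z_hyp, A_norm, theta):
    """Aggregate in the Euclidean tangent space, then map back to hyperbolic space."""
    tangent = log0(Z_hyp)
    return exp0(A_norm @ tangent @ theta)

rng = np.random.default_rng(0)
V = rng.normal(size=(4, 2))          # tangent vectors at the origin
Z_hyp = exp0(V)                      # hyperbolic embeddings inside the unit ball
A_norm = np.full((4, 4), 0.25)       # toy normalized adjacency
Z_next = hyperbolic_conv(Z_hyp, A_norm, np.eye(2))
```

Because the matrix products are carried out in the tangent space, only Euclidean operations are needed, and the output is guaranteed to lie inside the ball again.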


23.5 Deep graph embeddings 


In this section, we use graph neural networks to devise graph embeddings in the unsupervised and 
semi-supervised cases.

Figure 23.6: Euclidean (left) and hyperbolic (right) embeddings of a tree graph. Hyperbolic embeddings
learn natural hierarchies in the embedding space (depth indicated by color). Reprinted with permission from
[Cha+19a].


Figure 23.7: Unsupervised graph neural networks. Graph structure and input features are mapped to
low-dimensional embeddings using a graph neural network encoder. Embeddings are then decoded to compute a
graph regularization loss (unsupervised). Reprinted with permission from [Cha+21].


23.5.1 Unsupervised embeddings


In this section, we discuss unsupervised losses for GNNs, as illustrated in Figure 23.7. 


23.5.1.1 Structural deep network embedding 


The structural deep network embedding (SDNE) method of [WCZ16] uses auto-encoders which
preserve first and second-order node proximity. The SDNE encoder takes a row of the adjacency
matrix as input (setting $s(W) = W$) and produces node embeddings $Z = \mathrm{ENC}(W; \Theta^E)$. (Note that
this ignores any node features.) The SDNE decoder returns $\hat{W} = \mathrm{DEC}(Z; \Theta^D)$, a reconstruction
trained to recover the original graph adjacency matrix. SDNE preserves second-order node proximity
by minimizing the following loss:

$$\mathcal{L}_{G,\mathrm{RECON}}(W, \hat{W}; \Theta) = \| (s(W) - \hat{W}) \cdot \mathbb{I}(s(W) > 0) \|_F^2 + \alpha_{\mathrm{SDNE}} \sum_{ij} s(W)_{ij} \| Z_i - Z_j \|_2^2. \tag{23.30}$$

The first term is similar to the matrix factorization regularization objective, except that $\hat{W}$ is not
computed using outer products. The second term is used by distance-based shallow embedding
methods.
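
Both terms of this loss are straightforward to compute once the reconstruction and the embeddings are available. The sketch below assumes dense matrices and takes the decoder output as a given array; `sdne_loss` is a hypothetical name.

```python
import numpy as np

def sdne_loss(S, W_hat, Z, alpha):
    """Masked reconstruction error plus alpha times the weighted embedding distances."""
    recon = np.sum(((S - W_hat) * (S > 0)) ** 2)       # only observed edges are penalized
    diff = Z[:, None, :] - Z[None, :, :]
    pairwise = np.sum(S * np.sum(diff ** 2, axis=-1))  # pulls linked nodes together
    return float(recon + alpha * pairwise)

S = np.array([[0.0, 1.0], [1.0, 0.0]])  # a single undirected edge
W_hat = np.zeros((2, 2))                # a reconstruction that misses the edge
Z = np.array([[0.0], [1.0]])            # embeddings at distance 1
loss = sdne_loss(S, W_hat, Z, alpha=0.5)
```

Here the reconstruction term contributes 2 (one unit per direction of the missed edge) and the distance term contributes 0.5 × 2, so the total is 3.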
23.5.1.2 (Variational) graph auto-encoders


Kipf and Welling [KW16b] use graph convolutions (Section 23.4.2) to learn node embeddings $Z =
\mathrm{GCN}(W, X; \Theta^E)$. The decoder is an outer product: $\mathrm{DEC}(Z; \Theta^D) = ZZ^{\top}$. The graph reconstruction
term is the sigmoid cross entropy between the true adjacency and the predicted edge similarity scores:

$$\mathcal{L}_{G,\mathrm{RECON}}(W, \hat{W}; \Theta) = -\left( \sum_{i,j} (1 - W_{ij}) \log(1 - \sigma(\hat{W}_{ij})) + W_{ij} \log \sigma(\hat{W}_{ij}) \right). \tag{23.31}$$

Computing the regularization term over all possible node pairs is computationally challenging
in practice, so the Graph Auto Encoders (GAE) model uses negative sampling to overcome this
challenge.
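
A sketch of the negatively sampled loss is shown below. It assumes the embeddings are given (the GCN encoder is omitted) and, for simplicity, enumerates all non-edges before sampling, which is only viable for small graphs; the function names are hypothetical.

```python
import numpy as np

def log_sigmoid(x):
    return -np.logaddexp(0.0, -x)  # numerically stable log(sigma(x))

def sample_negative_pairs(W, num, rng):
    """Pick `num` distinct non-edges (i != j) uniformly at random."""
    candidates = np.argwhere((W == 0) & ~np.eye(W.shape[0], dtype=bool))
    return candidates[rng.choice(len(candidates), size=num, replace=False)]

def gae_loss(W, Z, neg_pairs):
    """Sigmoid cross entropy over all observed edges plus the sampled non-edges."""
    W_hat = Z @ Z.T
    pos_term = np.sum(log_sigmoid(W_hat[W > 0]))
    # log(1 - sigma(x)) equals log(sigma(-x)).
    neg_term = np.sum(log_sigmoid(-W_hat[neg_pairs[:, 0], neg_pairs[:, 1]]))
    return float(-(pos_term + neg_term))

# Path graph 0-1-2-3 (6 nonzero entries in the symmetric adjacency).
W = np.zeros((4, 4))
for i in range(3):
    W[i, i + 1] = W[i + 1, i] = 1.0
rng = np.random.default_rng(0)
neg_pairs = sample_negative_pairs(W, num=4, rng=rng)
loss = gae_loss(W, np.zeros((4, 2)), neg_pairs)  # all-zero embeddings as a sanity check
```

With all-zero embeddings every predicted probability is 0.5, so the loss equals the number of scored pairs (6 positives and 4 negatives) times $\log 2$.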

Whereas GAE is a deterministic model, the authors also introduce variational graph auto-encoders
(VGAE), which rely on variational auto-encoders (as in Section 20.3.5) to encode and decode
the graph structure. In VGAE, the embedding $Z$ is modeled as a latent variable with a standard
multivariate normal prior $p(Z) = \mathcal{N}(Z \mid 0, I)$ and a graph convolution is used as the amortized inference
network, $q_\Phi(Z \mid W, X)$. The model is trained by minimizing the corresponding negative evidence
lower bound:

$$\mathrm{NELBO}(W, X; \Theta) = -\mathbb{E}_{q_\Phi(Z \mid W, X)}[\log p(W \mid Z)] + \mathrm{KL}(q_\Phi(Z \mid W, X) \,\|\, p(Z)) \tag{23.32}$$

$$= \mathcal{L}_{G,\mathrm{RECON}}(W, \hat{W}; \Theta) + \mathrm{KL}(q_\Phi(Z \mid W, X) \,\|\, p(Z)). \tag{23.33}$$
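
When the inference network outputs a diagonal Gaussian, as is typical for variational auto-encoders, the KL term has a closed form. The sketch below assumes the encoder returns per-node means and log-variances (it is not shown) and uses hypothetical function names.

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """Draw Z = mu + sigma * eps with eps ~ N(0, I), keeping the sample differentiable in mu, sigma."""
    return mu + np.exp(0.5 * log_var) * rng.normal(size=mu.shape)

def kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over all nodes and dimensions."""
    return float(0.5 * np.sum(mu ** 2 + np.exp(log_var) - 1.0 - log_var))

rng = np.random.default_rng(0)
mu, log_var = np.ones((3, 2)), np.zeros((3, 2))  # toy encoder outputs for 3 nodes
Z = reparameterize(mu, log_var, rng)
kl = kl_to_standard_normal(mu, log_var)
kl_at_prior = kl_to_standard_normal(np.zeros((3, 2)), np.zeros((3, 2)))
```

The KL term vanishes when the posterior equals the prior, and for unit variances it reduces to half the squared norm of the means (here 3).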


23.5.1.3 Iterative generative modelling of graphs (Graphite) 


The Graphite model of [GZE19] extends GAE and VGAE by introducing a more complex decoder.
This decoder iterates between pairwise decoding functions and graph convolutions, as follows:

$$\hat{W}^{(k)} = \frac{Z^{(k)} {Z^{(k)}}^{\top}}{\| Z^{(k)} \|_2^2} + \frac{\mathbf{1}\mathbf{1}^{\top}}{N}$$

$$Z^{(k+1)} = \mathrm{GCN}(\hat{W}^{(k)}, Z^{(k)})$$

where $Z^{(0)}$ is initialized using the output of the encoder network. This process allows Graphite to
learn more expressive decoders. Finally, similar to GAE, Graphite can be deterministic or variational.


23.5.1.4 Methods based on contrastive losses 


The deep graph infomax (DGI) method of [Vel+19] is a GAN-like method for creating graph-level
embeddings. Given one or more real (positive) graphs, each with its adjacency matrix $W \in \mathbb{R}^{N \times N}$
and node features $X \in \mathbb{R}^{N \times D}$, this method creates fake (negative) adjacency matrices $W^- \in \mathbb{R}^{N^- \times N^-}$
and their features $X^- \in \mathbb{R}^{N^- \times D}$. It trains (i) an encoder that processes both real and fake samples,
respectively giving $Z = \mathrm{ENC}(X, W; \Theta^E) \in \mathbb{R}^{N \times L}$ and $Z^- = \mathrm{ENC}(X^-, W^-; \Theta^E) \in \mathbb{R}^{N^- \times L}$, (ii) a
(readout) graph pooling function $R : \mathbb{R}^{N \times L} \to \mathbb{R}^{L}$, and (iii) a discriminator function $D : \mathbb{R}^{L} \times \mathbb{R}^{L} \to
[0, 1]$ which is trained to output $D(Z_i, R(Z)) \approx 1$ and $D(Z^-_j, R(Z^-)) \approx 0$, respectively, for nodes
corresponding to given graph $i \in V$ and fake graph $j \in V^-$. Specifically, DGI optimizes:


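$$\min_{\Theta} \; -\mathbb{E}_{X, W} \sum_{i=1}^{N} \log D(Z_i, R(Z)) \; - \; \mathbb{E}_{X^-, W^-} \sum_{j=1}^{N^-} \log\left(1 - D(Z^-_j, R(Z^-))\right), \tag{23.34}$$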


where $\Theta$ contains $\Theta^E$ and the parameters of $R$, $D$. In the first expectation, DGI samples from
the real (positive) graphs. If only one graph is given, it could sample some subgraphs from it
(e.g., connected components). The second expectation samples fake (negative) graphs. In DGI, fake
samples use the real adjacency $W^- := W$ but fake features $X^-$ are a row-wise random permutation of
real $X$. The ENC used in DGI is a graph convolutional network, though any GNN can be used. The
readout $R$ summarizes an entire (variable-size) graph to a single (fixed-dimension) vector. Veličković
et al. [Vel+19] use $R$ as a row-wise mean, though other graph pooling functions might be used, e.g., ones aware
of the adjacency.

The optimization of Equation (23.34) is shown by [Vel+19] to maximize a lower-bound on the 
mutual information between the outputs of the encoder and the graph pooling function, i.e., between 
individual node representations and the graph representation. 
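
The pieces described above (corruption by row permutation, mean readout, discriminator) can be sketched as follows. The one-layer encoder and the bilinear discriminator are stand-in choices, and, following the description in the text, negative nodes are scored against the readout of the corrupted graph; all names are hypothetical.

```python
import numpy as np

def log_sigmoid(x):
    return -np.logaddexp(0.0, -x)  # numerically stable log(sigma(x))

def dgi_loss(X, A_norm, theta_enc, theta_disc, rng):
    """Binary cross entropy between real (node, summary) pairs and corrupted ones."""
    def encode(feats):
        return np.maximum(A_norm @ feats @ theta_enc, 0.0)  # one-layer stand-in encoder

    Z = encode(X)
    Z_neg = encode(X[rng.permutation(X.shape[0])])  # same graph, row-shuffled features
    summary, summary_neg = Z.mean(axis=0), Z_neg.mean(axis=0)  # row-wise mean readout
    pos_logits = Z @ theta_disc @ summary           # bilinear discriminator, before the sigmoid
    neg_logits = Z_neg @ theta_disc @ summary_neg
    return float(-(np.sum(log_sigmoid(pos_logits)) + np.sum(log_sigmoid(-neg_logits))))

rng = np.random.default_rng(0)
A_norm = np.full((4, 4), 0.25)  # toy normalized adjacency
X = rng.normal(size=(4, 3))
theta_enc = rng.normal(size=(3, 2))
loss_untrained = dgi_loss(X, A_norm, theta_enc, np.zeros((2, 2)), rng)
```

With an all-zero discriminator every pair is scored 0.5, so the loss starts at $2N \log 2$; training lowers it by making real and corrupted nodes distinguishable.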

[Pen+20] present a variant called Graphical Mutual Information (GMI). Rather than
maximizing the MI between node information and an entire graph, GMI maximizes the MI between the
representation of a node and its neighbors.


23.5.2 Semi-supervised embeddings


In this section, we discuss semi-supervised losses for GNNs. We consider the simple special case in 
which we use a nonlinear encoder of the node features, but ignore the graph structure, i.e., we use 
$Z = \mathrm{ENC}(X; \Theta^E)$.


23.5.2.1 SemiEmb 


[WRC08] propose an approach called semi-supervised embeddings (SemiEmb). They use an
MLP for the encoder of $X$. For the decoder, we can use a distance-based graph decoder: $\hat{W}_{ij} =
\mathrm{DEC}(Z; \Theta^D)_{ij} = \| Z_i - Z_j \|^2$, where $\| \cdot \|$ can be the L2 or L1 norm.

SemiEmb regularizes intermediate or auxiliary layers in the network using the same regularizer
as the label propagation loss in Equation (23.19). SemiEmb uses a feed forward network to predict
labels from intermediate embeddings, which are then compared to ground truth labels using the
hinge loss.


23.5.2.2 Planetoid 


Unsupervised skip-gram methods like DeepWalk and node2vec learn embeddings in a multi-step 
pipeline, where random walks are first generated from the graph and then used to learn embeddings. 
These embeddings are likely not optimal for downstream classification tasks. The Planetoid method 
of [YCS16] extends such random walk methods to leverage node label information during the 
embedding algorithm. 

Planetoid first maps nodes to embeddings $Z = [Z^c \,\|\, Z^F] = \mathrm{ENC}(X; \Theta^E)$ using a neural network
(again ignoring graph structure). The node embeddings $Z^c$ capture structural information while the
node embeddings $Z^F$ capture feature information. There are two variants, a transductive version
that directly learns $Z^c$ (as an embedding lookup), and an inductive model where $Z^c$ is computed
with parametric mappings that act on input features $X$. The Planetoid objective contains both a
supervised loss and a graph regularization loss. The graph regularization loss measures the ability to
predict context using node embeddings:

$$\mathcal{L}_{G,\mathrm{RECON}}(W, \hat{W}; \Theta) = -\mathbb{E}_{(i,j,\gamma)} \log \sigma(\gamma \hat{W}_{ij}), \tag{23.35}$$

with $\hat{W}_{ij} = Z_i^{\top} Z_j$ and $\gamma \in \{-1, 1\}$, where $\gamma = 1$ if $(v_i, v_j) \in E$ is a positive pair and $\gamma = -1$ if $(v_i, v_j)$
is a negative pair. The distribution under the expectation is directly defined through a sampling
process.

The supervised loss in Planetoid is the negative log-likelihood of predicting the correct labels: 


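$$\mathcal{L}^N_{\mathrm{SUP}}(y^N, \hat{y}^N; \Theta) = -\frac{1}{|V_L|} \sum_{i \mid v_i \in V_L} \sum_{1 \le k \le C} y^N_{ik} \log \hat{y}^N_{ik}, \tag{23.36}$$

where $i$ is a node’s index while $k$ indicates label classes, and $\hat{y}^N_i$ are computed using a neural network
followed by a softmax activation, mapping $Z_i$ to predicted labels.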


23.6 Applications 


There are many applications of graph embeddings, both unsupervised and supervised. We give some 
examples in the sections below. 


23.6.1 Unsupervised applications 


In this section, we discuss common unsupervised applications. 


23.6.1.1 Graph reconstruction 


A popular unsupervised graph application is graph reconstruction. In this setting, the goal is to 
learn mapping functions (which can be parametric or not) that map nodes onto a manifold which 
can reconstruct the graph. This is regarded as unsupervised in the sense that there is no supervision 
beyond the graph structure. Models can be trained by minimizing a reconstruction error, which is 
the error in recovering the original graph from learned embeddings. Several algorithms were designed 
specifically for this task, and we refer to Section 23.3.1 and Section 23.5.1 for some examples of 
reconstruction objectives. At a high level, graph reconstruction is similar to dimensionality reduction 
in the sense that the main goal is to summarize some input data into a low-dimensional embedding. 
Instead of compressing high dimensional vectors into low-dimensional ones as standard dimensionality 
reduction methods (e.g. PCA) do, the goal of graph reconstruction models is to compress data 
defined on graphs into low-dimensional vectors. 


23.6.1.2 Link prediction 


The goal in link prediction is to predict missing or unobserved links (e.g., links that may appear
in the future for dynamic and temporal networks). Link prediction can also help identify spurious
links and remove them. It is a major application of graph learning models in industry, and common
examples include predicting friendships in social networks, predicting user-product
interactions in recommendation systems, predicting suspicious links in a fraud detection
system (see Figure 23.8), or predicting missing relationships between entities in a knowledge
graph (see e.g., [Nic+15]).

Figure 23.8: A graph representation of some financial transactions. Adapted from http://pgql-lang.org/

A common approach for training link prediction models is to mask some edges in the graph 
(positive and negative edges), train a model with the remaining edges and then test it on the masked 
set of edges. Note that link prediction is different from graph reconstruction. In link prediction, we 
aim at predicting links that are not observed in the original graph while in graph reconstruction, we 
only want to compute embeddings that preserve the graph structure through reconstruction error 
minimization. 
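
The masking protocol can be sketched in plain Python. The sketch assumes an undirected graph given as an edge list with enough non-edges to sample from, and it holds out an equal number of positive and negative pairs; `split_edges` is a hypothetical helper.

```python
import random

def split_edges(edges, num_nodes, test_frac, rng):
    """Hold out a fraction of edges as positives and sample as many non-edges as negatives."""
    edges = sorted({tuple(sorted(e)) for e in edges})  # canonical undirected edges
    rng.shuffle(edges)
    n_test = max(1, int(test_frac * len(edges)))
    test_pos, train = edges[:n_test], edges[n_test:]
    existing = set(edges)
    test_neg = set()
    while len(test_neg) < n_test:  # rejection sampling; assumes enough non-edges exist
        u, v = rng.sample(range(num_nodes), 2)
        pair = (min(u, v), max(u, v))
        if pair not in existing:
            test_neg.add(pair)
    return train, test_pos, sorted(test_neg)

# A cycle on 6 nodes.
edges = [(i, (i + 1) % 6) for i in range(6)]
train, test_pos, test_neg = split_edges(edges, num_nodes=6, test_frac=0.5, rng=random.Random(0))
```

The model is then trained on `train` only and evaluated on how well it separates `test_pos` from `test_neg`, for example by ranking the decoder scores of the held-out pairs.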

Finally, while link prediction has similarities with supervised tasks in the sense that we have labels 
for edges (positive, negative, unobserved), we group it under the unsupervised class of applications 
since edge labels are usually not used during training, but only used to measure the predictive quality 
of embeddings. 


23.6.1.3 Clustering 


Clustering is particularly useful for discovering communities and has many real-world applications. 
For instance, clusters exist in biological networks (e.g. as groups of proteins with similar properties), 
or in social networks (e.g. as groups of people with similar interests). 

The unsupervised methods introduced in this chapter can be used to solve clustering problems
by applying a clustering algorithm (e.g., k-means) to embeddings that are output by an encoder.
Further, clustering can be joined with the learning algorithm while learning a shallow [Roz+19] or
Graph Convolution [Chi+19a; CEL19] embedding model.
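
The first, decoupled approach can be sketched with a small k-means routine applied to precomputed embeddings. The embeddings here are synthetic stand-ins for encoder outputs, and the farthest-first initialization is an illustrative choice that keeps the example deterministic.

```python
import numpy as np

def kmeans(Z, k, num_iters=50):
    """Lloyd's algorithm with farthest-first initialization."""
    centers = [Z[0]]
    for _ in range(k - 1):
        d = np.min([np.sum((Z - c) ** 2, axis=1) for c in centers], axis=0)
        centers.append(Z[np.argmax(d)])  # next center: the point farthest from all chosen ones
    centers = np.array(centers)
    for _ in range(num_iters):
        dists = np.sum((Z[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
        assign = dists.argmin(axis=1)
        for c in range(k):
            if np.any(assign == c):      # keep the old center if a cluster empties
                centers[c] = Z[assign == c].mean(axis=0)
    return assign, centers

rng = np.random.default_rng(0)
# Stand-in node embeddings forming two well-separated communities.
Z = np.vstack([rng.normal(0.0, 0.1, size=(5, 2)), rng.normal(5.0, 0.1, size=(5, 2))])
assign, centers = kmeans(Z, k=2)
```

Nodes whose embeddings lie close together receive the same cluster index, so community structure captured by the encoder carries over to the clustering.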


23.6.1.4 Visualization 


There are many off-the-shelf tools for mapping graph nodes onto two-dimensional manifolds for the 
purpose of visualization. Visualizations allow network scientists to qualitatively understand graph 
properties, understand relationships between nodes or visualize node clusters. Among the popular 
tools are methods based on force-directed layouts, with various web-app JavaScript implementations.

Unsupervised graph embedding methods are also used for visualization purposes: by first training 
an encoder-decoder model (corresponding to a shallow embedding or graph convolution network), and 
then mapping every node representation onto a two-dimensional space using t-SNE (Section 20.4.10) 
or PCA (Section 20.1). Such a process (embedding → dimensionality reduction) is commonly used to
qualitatively evaluate the performance of graph learning algorithms. If nodes have attributes, one can 
use these attributes to color the nodes on 2D visualization plots. Good embedding algorithms embed 
nodes that have similar attributes nearby in the embedding space, as demonstrated in visualizations 
of various methods [PARS14; KW16a; AEH+18]. Finally, beyond mapping every node to a 2D 
coordinate, methods which map every graph to a representation [ARZP19] can similarly be projected 
into two dimensions to visualize and qualitatively analyze graph-level properties. 


23.6.2 Supervised applications 


In this section, we discuss common supervised applications. 


23.6.2.1 Node classification 


Node classification is an important supervised graph application, where the goal is to learn node
representations that can accurately predict node labels. (This is sometimes called statistical
relational learning [GT07].) For instance, node labels could be scientific topics in citation networks,
or gender and other attributes in social networks.

Since labeling large graphs can be time-consuming and expensive, semi-supervised node classification 
is a particularly common application. In semi-supervised settings, only a fraction of nodes are labeled 
and the goal is to leverage links between nodes to predict attributes of unlabeled nodes. This setting 
is transductive since there is only one partially labeled fixed graph. It is also possible to do inductive 
node classification, which corresponds to the task of classifying nodes in multiple graphs. 

Note that node features can significantly boost the performance on node classification tasks if
these are descriptive for the target label. Indeed, recent methods such as GCN (Section 23.4.2) and
GraphSAGE (Section 23.4.3.1) have achieved state-of-the-art performance on multiple node classification
benchmarks due to their ability to combine structural information and semantics coming from
features. On the other hand, other methods such as random walks on graphs fail to leverage feature
information and therefore achieve lower performance on these tasks.

Figure 23.9: Structurally similar molecules do not necessarily have similar odor descriptors. (A) Lyral, the
reference molecule. (B) Molecules with similar structure can share similar odor descriptors. (C) However, a
small structural change can render the molecule odorless. (D) Further, large structural changes can leave the
odor of the molecule largely unchanged. From Figure 1 of [SL+19], originally from [OPK12]. Used with kind
permission of Benjamin Sanchez-Lengeling.


23.6.2.2 Graph classification 


Graph classification is a supervised application where the goal is to predict graph labels. Graph 
classification problems are inductive and a common example is classifying chemical compounds (e.g. 
predicting toxicity or odor from a molecule, as shown in Figure 23.9). 

Graph classification requires some notion of pooling, in order to aggregate node-level information 
into graph-level information. As discussed earlier, generalizing this notion of pooling to arbitrary 
graphs is nontrivial because of the lack of regularity in the graph structure, which makes graph 
pooling an active research area. In addition to the supervised methods discussed above, a number of 
unsupervised methods for learning graph-level representations have been proposed [Tsi+18; ARZP19; TMP20]. 
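
To make the idea of pooling concrete, here is a minimal sketch of two simple readout functions, which aggregate a set of node embeddings into a single graph-level vector by averaging or summing. This is only an illustration under stated assumptions: the function names `mean_pool` and `sum_pool` and the toy embeddings are hypothetical, and these simple readouts are not necessarily the pooling operators used by the methods cited above. Both are invariant to the order in which the nodes are listed, which is the minimal property a graph-level readout needs.

```python
def mean_pool(node_embeddings):
    """Average a list of node vectors into one graph-level vector."""
    num_nodes = len(node_embeddings)
    dim = len(node_embeddings[0])
    return [sum(h[d] for h in node_embeddings) / num_nodes for d in range(dim)]


def sum_pool(node_embeddings):
    """Sum a list of node vectors into one graph-level vector."""
    dim = len(node_embeddings[0])
    return [sum(h[d] for h in node_embeddings) for d in range(dim)]


# Toy graph with three nodes and 2-dimensional embeddings (hypothetical values).
graph_a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# The same nodes listed in a different order.
graph_b = [[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]

graph_vector = mean_pool(graph_a)
```

The resulting graph-level vector could then be passed to any standard classifier. Note that sum pooling retains information about the number of nodes, whereas mean pooling does not.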


A Notation 


A.1 Introduction 


It is very difficult to come up with a single, consistent notation to cover the wide variety of data, 
models and algorithms that we discuss in this book. Furthermore, conventions differ between different 
fields (such as machine learning, statistics and optimization), and between different books and papers 
within the same field. Nevertheless, we have tried to be as consistent as possible. Below we summarize 
most of the notation used in this book, although individual sections may introduce new notation. 
Note also that the same symbol may have different meanings depending on the context, although we 
try to avoid this where possible. 


A.2 Common mathematical symbols 
We list some common symbols below. 


Symbol                    Meaning 

$\infty$                  Infinity 
$\rightarrow$             Tends towards, e.g., $n \rightarrow \infty$ 
$\propto$                 Proportional to, so $y = ax$ can be written as $y \propto x$ 
$\triangleq$              Defined as 
$O(\cdot)$                Big-O: roughly means order of magnitude 
$\mathbb{Z}_+$            The positive integers 
$\mathbb{R}$              The real numbers 
$\mathbb{R}_+$            The positive reals 
$\mathcal{S}_K$           The $K$-dimensional probability simplex 
$\mathcal{S}^D_{++}$      Cone of positive definite $D \times D$ matrices 
$\approx$                 Approximately equal to 
$\{1, \ldots, N\}$        The finite set $\{1, 2, \ldots, N\}$ 
$1:N$                     The finite set $\{1, 2, \ldots, N\}$ 
$[\ell, u]$               The continuous interval $\{\ell \leq x \leq u\}$ 


A.3 Functions 


Generic functions will be denoted by $f$ (and sometimes $g$ or $h$). We will encounter many named 
functions, such as $\tanh(x)$ or $\sigma(x)$. A scalar function applied to a vector is assumed to be 
applied elementwise, e.g., $\mathbf{x}^2 = [x_1^2, \ldots, x_D^2]$. Functionals (functions of a function) 
are written using “blackboard” font, e.g., $\mathbb{H}(p)$ for the entropy of a distribution $p$. A 
function parameterized by fixed parameters $\boldsymbol{\theta}$ will be denoted by 
$f(\mathbf{x}; \boldsymbol{\theta})$ or sometimes $f_{\boldsymbol{\theta}}(\mathbf{x})$. We list some 
common functions (with no free parameters) below. 


A.3.1 Common functions of one argument 


Symbol                    Meaning 

$\lfloor x \rfloor$       Floor of $x$, i.e., round down to nearest integer 
$\lceil x \rceil$         Ceiling of $x$, i.e., round up to nearest integer 
$\neg a$                  Logical NOT 
$\mathbb{I}(x)$           Indicator function, $\mathbb{I}(x) = 1$ if $x$ is true, else $\mathbb{I}(x) = 0$ 
$\delta(x)$               Dirac delta function, $\delta(x) = \infty$ if $x = 0$, else $\delta(x) = 0$ 
$|x|$                     Absolute value 
$|\mathcal{S}|$           Size (cardinality) of a set 
$n!$                      Factorial function 
$\log(x)$                 Natural logarithm of $x$ 
$\exp(x)$                 Exponential function $e^x$ 
$\Gamma(x)$               Gamma function, $\Gamma(x) = \int_0^\infty u^{x-1} e^{-u} du$ 
$\Psi(x)$                 Digamma function, $\Psi(x) = \frac{d}{dx} \log \Gamma(x)$ 
$\sigma(x)$               Sigmoid (logistic) function, $\frac{1}{1 + e^{-x}}$ 


A.3.2 Common functions of two arguments 


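Symbol                    Meaning 

$a \wedge b$              Logical AND 
$a \vee b$                Logical OR 
$B(a, b)$                 Beta function, $B(a, b) = \frac{\Gamma(a) \Gamma(b)}{\Gamma(a + b)}$ 
$\binom{n}{k}$            $n$ choose $k$, equal to $n! / (k! (n - k)!)$ 
$\delta_{ij}$             Kronecker delta, equals $\mathbb{I}(i = j)$ 
$\mathbf{u} \odot \mathbf{v}$       Elementwise product of two vectors 
$\mathbf{u} \circledast \mathbf{v}$ Convolution of two vectors 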


A.3.3 Common functions of > 2 arguments 


Symbol                    Meaning 

$B(\mathbf{x})$           Multivariate beta function, $\frac{\prod_k \Gamma(x_k)}{\Gamma(\sum_k x_k)}$ 
$\Gamma_D(x)$             Multi. gamma function, $\pi^{D(D-1)/4} \prod_{d=1}^D \Gamma(x + (1 - d)/2)$ 
softmax$(\mathbf{x})$     Softmax function, $\left[ \frac{e^{x_c}}{\sum_{c'=1}^C e^{x_{c'}}} \right]_{c=1}^C$ 


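
As an illustration of two of the named functions listed above, the following is a minimal sketch of the sigmoid and softmax functions in plain Python. The function names are our own choice for this example. The sketch also assumes two standard numerical precautions that are not part of the definitions themselves: the sigmoid is evaluated in a form that avoids overflow for large negative inputs, and the softmax subtracts the maximum entry before exponentiating (which leaves the result unchanged, since the common factor cancels in the ratio).

```python
import math


def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), evaluated without overflow."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, use the equivalent form exp(x) / (1 + exp(x)).
    e = math.exp(x)
    return e / (1.0 + e)


def softmax(x):
    """Map a list of C real numbers to a probability vector of length C."""
    m = max(x)  # shifting by the max does not change the output
    exps = [math.exp(xc - m) for xc in x]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([1.0, 2.0, 3.0])
```

The output of `softmax` is non-negative and sums to one, so it lies in the probability simplex defined in Section A.2.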
A.4 Linear algebra 


In this section, we summarize the notation we use for linear algebra (see Chapter 7 for details). 


A.4.1 General notation 


Vectors are bold lower case letters such as $\mathbf{x}$, $\mathbf{w}$. Matrices are bold upper case 
letters, such as $\mathbf{X}$, $\mathbf{W}$. Scalars are non-bold lower case. When creating a vector from 
a list of $N$ scalars, we write $\mathbf{x} = [x_1, \ldots, x_N]$; this may be a column vector or a row 
vector, depending on the context. (Vectors are assumed to be column vectors, unless noted otherwise.) 
When creating an $M \times N$ matrix from a list of vectors, we write 
$\mathbf{X} = [\mathbf{x}_{:,1}, \ldots, \mathbf{x}_{:,N}]$ if we stack along the columns, or 
$\mathbf{X} = [\mathbf{x}_{1,:}; \ldots; \mathbf{x}_{M,:}]$ if we stack along the rows. 


A.4.2 Vectors 


Here is some standard notation for vectors. (We assume u and v are both N-dimensional vectors.) 


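Symbol                                  Meaning 

$\mathbf{u}^{\mathsf{T}} \mathbf{v}$    Inner (scalar) product, $\sum_{i=1}^N u_i v_i$ 
$\mathbf{u} \mathbf{v}^{\mathsf{T}}$    Outer product ($N \times N$ matrix) 
$\mathbf{u} \odot \mathbf{v}$           Elementwise product, $[u_1 v_1, \ldots, u_N v_N]$ 
$\mathbf{v}^{\mathsf{T}}$               Transpose of $\mathbf{v}$ 
dim$(\mathbf{v})$                       Dimensionality of $\mathbf{v}$ (namely $N$) 
diag$(\mathbf{v})$                      Diagonal $N \times N$ matrix made from vector $\mathbf{v}$ 
$\mathbf{1}$ or $\mathbf{1}_N$          Vector of ones (of length $N$) 
$\mathbf{0}$ or $\mathbf{0}_N$          Vector of zeros (of length $N$) 
$||\mathbf{v}|| = ||\mathbf{v}||_2$     Euclidean or $\ell_2$ norm, $\sqrt{\sum_{i=1}^N v_i^2}$ 
$||\mathbf{v}||_1$                      $\ell_1$ norm, $\sum_{i=1}^N |v_i|$ 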


A.4.3 Matrices 


Here is some standard notation for matrices. (We assume S is a square N x N matrix, X and Y are 
of size M x N, and Z is of size M’ x N’.) 


Symbol                            Meaning 

$\mathbf{X}_{:,j}$                $j$th column of matrix 
$\mathbf{X}_{i,:}$                $i$th row of matrix (treated as a column vector) 
$X_{ij}$                          Element $(i, j)$ of matrix 
$\mathbf{S} \succ 0$              True iff $\mathbf{S}$ is a positive definite matrix 
tr$(\mathbf{S})$                  Trace of a square matrix 
det$(\mathbf{S})$                 Determinant of a square matrix 
$|\mathbf{S}|$                    Determinant of a square matrix 
$\mathbf{S}^{-1}$                 Inverse of a square matrix 
$\mathbf{X}^{\dagger}$            Pseudo-inverse of a matrix 
$\mathbf{X}^{\mathsf{T}}$         Transpose of a matrix 
diag$(\mathbf{S})$                Diagonal vector extracted from square matrix 
$\mathbf{I}$ or $\mathbf{I}_N$    Identity matrix of size $N \times N$ 
$\mathbf{X} \odot \mathbf{Y}$     Elementwise product 
$\mathbf{X} \otimes \mathbf{Z}$   Kronecker product (see Section 7.2.5) 


A.4.4 Matrix calculus 


In this section, we summarize the notation we use for matrix calculus (see Section 7.8 for details). 
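Let $\boldsymbol{\theta} \in \mathbb{R}^N$ be a vector and $f : \mathbb{R}^N \rightarrow \mathbb{R}$ be a 
scalar valued function. The derivative of $f$ wrt its argument is denoted by the following: 


$$\nabla_{\boldsymbol{\theta}} f(\boldsymbol{\theta}) \triangleq \nabla f(\boldsymbol{\theta}) \triangleq \nabla f \triangleq \left( \frac{\partial f}{\partial \theta_1} \;\; \cdots \;\; \frac{\partial f}{\partial \theta_N} \right) \tag{A.1}$$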


The gradient is a vector that must be evaluated at a point in space. To emphasize this, we will 
sometimes write 


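$$\mathbf{g}_t \triangleq \mathbf{g}(\boldsymbol{\theta}_t) \triangleq \nabla f(\boldsymbol{\theta}) \big|_{\boldsymbol{\theta}_t} \tag{A.2}$$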


We can also compute the (symmetric) N x N matrix of second partial derivatives, known as the 
Hessian: 


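$$\nabla^2 f \triangleq \begin{pmatrix} \frac{\partial^2 f}{\partial \theta_1^2} & \cdots & \frac{\partial^2 f}{\partial \theta_1 \partial \theta_N} \\ & \vdots & \\ \frac{\partial^2 f}{\partial \theta_N \partial \theta_1} & \cdots & \frac{\partial^2 f}{\partial \theta_N^2} \end{pmatrix} \tag{A.3}$$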


The Hessian is a matrix that must be evaluated at a point in space. To emphasize this, we will 
sometimes write 


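$$\mathbf{H}_t \triangleq \mathbf{H}(\boldsymbol{\theta}_t) \triangleq \nabla^2 f(\boldsymbol{\theta}) \big|_{\boldsymbol{\theta}_t} \tag{A.4}$$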


A.5 Optimization 


In this section, we summarize the notation we use for optimization (see Chapter 8 for details). 

We will often write an objective or cost function that we wish to minimize as 
$\mathcal{L}(\boldsymbol{\theta})$, where $\boldsymbol{\theta}$ are the variables to be optimized (often 
thought of as parameters of a statistical model). We denote the parameter value that achieves the 
minimum as $\boldsymbol{\theta}_* = \operatorname{argmin}_{\boldsymbol{\theta} \in \Theta} \mathcal{L}(\boldsymbol{\theta})$, 
where $\Theta$ is the set we are optimizing over. (Note that there may be more than one such optimal 
value, so we should really write 
$\boldsymbol{\theta}_* \in \operatorname{argmin}_{\boldsymbol{\theta} \in \Theta} \mathcal{L}(\boldsymbol{\theta})$.) 

When performing iterative optimization, we use $t$ to index the iteration number. We use $\eta$ as a 
step size (learning rate) parameter. Thus we can write the gradient descent algorithm (explained in 
Section 8.4) as follows: $\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta \mathbf{g}_t$. 

We often use a hat symbol to denote an estimate or prediction (e.g., $\hat{\boldsymbol{\theta}}$, 
$\hat{y}$), a star subscript or superscript to denote a true (but usually unknown) value (e.g., 
$\boldsymbol{\theta}_*$ or $\boldsymbol{\theta}^*$), and an overline to denote a mean value (e.g., 
$\overline{\boldsymbol{\theta}}$). 
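
The gradient descent update written in the notation of this section can be sketched in a few lines of Python. This is only an illustrative sketch: the function name `gradient_descent`, the fixed step size, the fixed number of iterations, and the toy quadratic objective are all assumptions made for this example, and Section 8.4 should be consulted for the actual algorithm and its variants.

```python
def gradient_descent(grad, theta0, eta=0.1, num_steps=100):
    """Iterate theta_{t+1} = theta_t - eta * g_t, where g_t = grad(theta_t)."""
    theta = list(theta0)
    for t in range(num_steps):
        g = grad(theta)  # gradient evaluated at the current iterate
        theta = [th - eta * gi for th, gi in zip(theta, g)]
    return theta


def grad_quadratic(theta):
    """Gradient of L(theta) = (theta_1 - 3)^2 + (theta_2 + 1)^2."""
    return [2.0 * (theta[0] - 3.0), 2.0 * (theta[1] + 1.0)]


# The minimizer of this toy objective is theta_* = (3, -1).
theta_hat = gradient_descent(grad_quadratic, [0.0, 0.0])
```

Here `theta_hat` plays the role of an estimate of the minimizer. For this objective each step shrinks the distance to the minimizer by a factor of 0.8, so 100 steps are more than enough to converge to high precision.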


A.6 Probability 


In this section, we summarize the notation we use for probability theory (see Chapter 2 for details). 

We denote a probability density function (pdf) or probability mass function (pmf) by p, a cumulative 
distribution function (cdf) by P, and the probability of a binary event by Pr. We write p(X) for the 
distribution for random variable X, and p(Y) for the distribution for random variable Y — these 
refer to different distributions, even though we use the same p symbol in both cases. (In cases where 
confusion may arise, we write $p_X(\cdot)$ and $p_Y(\cdot)$.) Approximations to a distribution $p$ will 
often be represented by $q$, or sometimes $\hat{p}$. 

In some cases, we distinguish between a random variable (rv) and the values it can take on. In this 
case, we denote the variable in upper case (e.g., X), and its value in lower case (e.g., x). However, 
we often ignore this distinction between variables and values. For example, we sometimes write p(x) 
to denote either the scalar value (the distribution evaluated at a point) or the distribution itself, 
depending on whether X is observed or not. 

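We write $X \sim p$ to denote that $X$ is distributed according to distribution $p$. We write 
$X \perp Y \mid Z$ to denote that $X$ is conditionally independent of $Y$ given $Z$. If $X \sim p$, we 
denote the expected value of $f(X)$ using 


$$\mathbb{E}[f(X)] = \mathbb{E}_{p(X)}[f(X)] = \mathbb{E}_X[f(X)] = \int_x f(x) p(x) dx \tag{A.5}$$


If $f$ is the identity function, we write $\overline{X} \triangleq \mathbb{E}[X]$. Similarly, the 
variance is denoted by 


$$\mathbb{V}[f(X)] = \mathbb{V}_{p(X)}[f(X)] = \mathbb{V}_X[f(X)] = \int_x (f(x) - \mathbb{E}[f(X)])^2 p(x) dx \tag{A.6}$$

If $\mathbf{x}$ is a random vector, the covariance matrix is denoted 

$$\text{Cov}[\mathbf{x}] = \mathbb{E}\left[ (\mathbf{x} - \overline{\mathbf{x}}) (\mathbf{x} - \overline{\mathbf{x}})^{\mathsf{T}} \right] \tag{A.7}$$

If $X \sim p$, the mode of a distribution is denoted by 


$$\hat{x} = \text{mode}[p] = \operatorname{argmax}_x p(x) \tag{A.8}$$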
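
For a discrete random variable, the integrals in Equations (A.5) and (A.6) become sums over the support, and the mode in Equation (A.8) is the most probable value. The following minimal sketch illustrates this; the dictionary representation of the pmf, the function names, and the toy probabilities are assumptions made for this example only.

```python
def expectation(f, pmf):
    """E[f(X)] = sum_x f(x) p(x) for a pmf given as {value: probability}."""
    return sum(f(x) * p for x, p in pmf.items())


def variance(f, pmf):
    """V[f(X)] = sum_x (f(x) - E[f(X)])^2 p(x)."""
    mean = expectation(f, pmf)
    return sum((f(x) - mean) ** 2 * p for x, p in pmf.items())


def mode(pmf):
    """argmax_x p(x)."""
    return max(pmf, key=pmf.get)


def identity(x):
    return x


# A toy pmf over the values {0, 1, 2} (hypothetical numbers).
pmf = {0: 0.2, 1: 0.5, 2: 0.3}

mean_x = expectation(identity, pmf)  # approximately 1.1
var_x = variance(identity, pmf)      # approximately 0.49
x_hat = mode(pmf)                    # 1
```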


We denote parametric distributions using $p(\mathbf{x}|\boldsymbol{\theta})$, where $\mathbf{x}$ are the 
random variables, $\boldsymbol{\theta}$ are the parameters and $p$ is a pdf or pmf. For example, 
$\mathcal{N}(x|\mu, \sigma^2)$ is a Gaussian (normal) distribution with mean $\mu$ and standard 
deviation $\sigma$. 


A.7 Information theory 


In this section, we summarize the notation we use for information theory (see Chapter 6 for details). 

If $X \sim p$, we denote the (differential) entropy of the distribution by $\mathbb{H}(X)$ or 
$\mathbb{H}(p)$. If $Y \sim q$, we denote the KL divergence from distribution $p$ to $q$ by 
$D_{\mathbb{KL}}(p \parallel q)$. If $(X, Y) \sim p$, we denote the mutual information between $X$ and 
$Y$ by $\mathbb{I}(X; Y)$. 
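
The quantities named above can be computed directly for discrete distributions. The sketch below uses the standard definitions for the discrete case (see Chapter 6 for the definitions used in the book), with natural logarithms, so the results are in nats rather than bits. The function names and the list-based representation of distributions are assumptions made for this example, and `kl_divergence` assumes that $q$ is non-zero wherever $p$ is non-zero.

```python
import math


def entropy(p):
    """H(p) = -sum_k p_k log p_k, with the convention 0 log 0 = 0."""
    return -sum(pk * math.log(pk) for pk in p if pk > 0)


def kl_divergence(p, q):
    """D_KL(p || q) = sum_k p_k log(p_k / q_k); assumes q_k > 0 where p_k > 0."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)


def mutual_information(joint):
    """I(X; Y) for a joint pmf stored as a list of rows, joint[i][j] = p(x_i, y_j)."""
    px = [sum(row) for row in joint]        # marginal of X
    py = [sum(col) for col in zip(*joint)]  # marginal of Y
    return sum(
        pxy * math.log(pxy / (px[i] * py[j]))
        for i, row in enumerate(joint)
        for j, pxy in enumerate(row)
        if pxy > 0
    )


h_fair_coin = entropy([0.5, 0.5])                                # log(2) nats
mi_independent = mutual_information([[0.25, 0.25], [0.25, 0.25]])  # 0
mi_dependent = mutual_information([[0.5, 0.0], [0.0, 0.5]])        # log(2) nats
```

The mutual information is computed here as the KL divergence between the joint distribution and the product of its marginals, so it is zero when $X$ and $Y$ are independent.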


A.8 Statistics and machine learning 


We briefly summarize the notation we use for statistical learning. 


A.8.1 Supervised learning 


For supervised learning, we denote the observed features (also called inputs or covariates) by 
$\mathbf{x} \in \mathcal{X}$. Often $\mathcal{X} = \mathbb{R}^D$, meaning the features are real-valued. 
(Note that this includes the case of discrete-valued inputs, which can be represented as one-hot 
vectors.) Sometimes we compute manually-specified features of the input; we denote these by 
$\phi(\mathbf{x})$. We also have outputs (also called targets or response variables) 
$\mathbf{y} \in \mathcal{Y}$ that we wish to predict. Our task is to learn a conditional probability 
distribution $p(\mathbf{y}|\mathbf{x}, \boldsymbol{\theta})$, where $\boldsymbol{\theta}$ are the 
parameters of the model. If $\mathcal{Y} = \{1, \ldots, C\}$, we call this classification. If 
$\mathcal{Y} = \mathbb{R}^C$, we call this regression (often $C = 1$, so we are just predicting a scalar 
response). 

The parameters $\boldsymbol{\theta}$ are estimated from training data, denoted by 
$\mathcal{D} = \{(\mathbf{x}_n, \mathbf{y}_n) : n \in \{1, \ldots, N\}\}$ (so $N$ is the number of 
training cases). If $\mathcal{X} = \mathbb{R}^D$, we can store the training inputs in an $N \times D$ 
design matrix denoted by $\mathbf{X}$. If $\mathcal{Y} = \mathbb{R}^C$, we can store the training outputs 
in an $N \times C$ matrix $\mathbf{Y}$. If $\mathcal{Y} = \{1, \ldots, C\}$, we can represent each class 
label as a $C$-dimensional bit vector, with one element turned on (this is known as a one-hot 
encoding), so we can store the training outputs in an $N \times C$ binary matrix $\mathbf{Y}$. 
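
As a small illustration of the one-hot encoding described above, the following sketch converts a list of $N$ class labels in $\{1, \ldots, C\}$ into an $N \times C$ binary matrix, stored as a list of rows. The function name `one_hot` and the toy labels are assumptions made for this example.

```python
def one_hot(labels, num_classes):
    """Encode labels in {1, ..., C} as rows of an N x C binary matrix."""
    Y = []
    for y in labels:
        row = [0] * num_classes
        row[y - 1] = 1  # labels are 1-based, list indices are 0-based
        Y.append(row)
    return Y


# N = 3 training labels with C = 3 classes.
Y = one_hot([1, 3, 2], num_classes=3)
```

Each row of `Y` has exactly one element turned on, in the position of the corresponding class label.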


A.8.2 Unsupervised learning and generative models 


Unsupervised learning is usually formalized as the task of unconditional density estimation, namely 
modeling $p(\mathbf{x}|\boldsymbol{\theta})$. In some cases, we want to perform conditional density 
estimation; we denote the values we are conditioning on by $\mathbf{u}$, so the model becomes 
$p(\mathbf{x}|\mathbf{u}, \boldsymbol{\theta})$. This is similar to supervised learning, except that 
$\mathbf{x}$ is usually high dimensional (e.g., an image) and $\mathbf{u}$ is usually low dimensional 
(e.g., a class label or a text description). 

In some models, we have latent variables, also called hidden variables, which are never observed 
in the training data. We call such models latent variable models (LVM). We denote the latent 
variables for data case $n$ by $\mathbf{z}_n \in \mathcal{Z}$. When latent variables are referred to as 
hidden variables, they are sometimes denoted by $\mathbf{h}_n$; by contrast, the visible variables are 
denoted by $\mathbf{v}_n$. Typically the latent variables are continuous or discrete, i.e., 
$\mathcal{Z} = \mathbb{R}^L$ or $\mathcal{Z} = \{1, \ldots, K\}$. 

Most LVMs have the form $p(\mathbf{x}_n, \mathbf{z}_n|\boldsymbol{\theta})$; such models can be used for 
unsupervised learning. However, LVMs can also be used for supervised learning. In particular, we can 
either create a generative (unconditional) model of the form 
$p(\mathbf{x}_n, \mathbf{y}_n, \mathbf{z}_n|\boldsymbol{\theta})$, or a discriminative (conditional) 
model of the form $p(\mathbf{y}_n, \mathbf{z}_n|\mathbf{x}_n, \boldsymbol{\theta})$. 


A.8.3 Bayesian inference 


When working with Bayesian inference, we write the prior over the parameters as 
$p(\boldsymbol{\theta}|\boldsymbol{\xi})$, where $\boldsymbol{\xi}$ are the hyperparameters. For conjugate 
models, the posterior has the same form as the prior (by definition). We can therefore just update the 
hyperparameters from their prior value, $\breve{\boldsymbol{\xi}}$, to their posterior value, 
$\widehat{\boldsymbol{\xi}}$. 
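
As a concrete instance of updating hyperparameters in a conjugate model, consider a Bernoulli likelihood with a beta prior. With a prior $\text{Beta}(a, b)$ and data containing $N_1$ ones and $N_0$ zeros, the posterior is $\text{Beta}(a + N_1, b + N_0)$. The following minimal sketch implements this update; the function names, the tuple representation of the hyperparameters, and the toy data are assumptions made for this example.

```python
def beta_bernoulli_update(prior, data):
    """Map prior hyperparameters (a, b) to posterior hyperparameters.

    `data` is a list of binary observations (0 or 1).
    """
    a, b = prior
    n1 = sum(data)        # number of ones
    n0 = len(data) - n1   # number of zeros
    return (a + n1, b + n0)


def beta_mean(hyper):
    """Mean of a Beta(a, b) distribution, a / (a + b)."""
    a, b = hyper
    return a / (a + b)


prior_hyper = (1.0, 1.0)  # uniform prior on the success probability
posterior_hyper = beta_bernoulli_update(prior_hyper, [1, 1, 0, 1])
posterior_mean = beta_mean(posterior_hyper)
```

Because the posterior is again a beta distribution, the same function can be reused when more data arrives, by passing the current posterior hyperparameters as the new prior.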

In variational inference (Section 4.6.8.3), we use $\boldsymbol{\psi}$ to represent the parameters of the 
variational posterior, i.e., $p(\boldsymbol{\theta}|\mathcal{D}) \approx q(\boldsymbol{\theta}|\boldsymbol{\psi})$. 
We optimize the ELBO wrt $\boldsymbol{\psi}$ to make this a good approximation. 

When performing Monte Carlo sampling, we use an $s$ subscript or superscript to denote a sample 
(e.g., $\boldsymbol{\theta}_s$ or $\boldsymbol{\theta}^s$). 


A.9 Abbreviations 


Here are some of the abbreviations used in the book. 


Abbreviation Meaning 


cdf Cumulative distribution function 
CNN Convolutional neural network 
DAG Directed acyclic graph 

DML Deep metric learning 

DNN Deep neural network 

dof Degrees of freedom 

EB Empirical Bayes 

EM Expectation maximization algorithm 
GLM Generalized linear model 

GMM Gaussian mixture model 

HMC Hamiltonian Monte Carlo 

HMM Hidden Markov model 

iid Independent and identically distributed 
iff If and only if 

KDE Kernel density estimation 

KL Kullback-Leibler divergence 
KNN K nearest neighbor 

LHS Left hand side (of an equation) 
LSTM Long short term memory (a kind of RNN) 
LVM Latent variable model 

MAP Maximum a posteriori estimate 
MCMC Markov chain Monte Carlo 

MLE Maximum likelihood estimate 
MLP Multilayer perceptron 

MSE Mean squared error 

NLL Negative log likelihood 

OLS Ordinary least squares 

psd Positive semidefinite (matrix) 

pdf Probability density function 

pmf Probability mass function 

PNLL Penalized NLL 

PGM Probabilistic graphical model 
RNN Recurrent neural network 

RHS Right hand side (of an equation) 
RSS Residual sum of squares 

rv Random variable 

RVM Relevance vector machine 
SGD Stochastic gradient descent 
SSE Sum of squared errors 
SVI Stochastic variational inference 
SVM Support vector machine 
VB Variational Bayes 
wrt With respect to 


à trous algorithm, 484 
AMSGRAD, 299 
ADADELTA, 298 
ADAGRAD, 297 
ADAM, 298 
PADAM, 299 
RMSPROP, 297 
RPROP, 297 
YOGI, 299 

1x1 convolution, 472 


abstractive summarization, 540 
action potential, 434 

actions, 167 

activation function, 426 
activation maximization, 493 
active, 303 

active learning, 405, 648 

active set, 397 

activity regularization, 681 
Adaboost.M1, 613 
AdaBoostClassifier, 613 

Adam, 445 

Adamic/Adar, 756 

adapters, 629 

adaptive basis functions, 609 
adaptive instance normalization, 499 
adaptive learning rate, 297, 299 
add-one smoothing, 122, 132, 333 
additive attention, 519 

additive model, 609 

adjoint, 443 

adjusted Rand index, 715 
admissible, 191 

affine function, 8, 8 

agent, 17, 167 

aggregated gradient, 296 

AGI, 29 

AI, 28 

AI ethics, 28 

AI safety, 28 

Akaike information criterion, 185 
aleatoric uncertainty, 7, 34 
AlexNet, 479 

alignment, 520 

alignment problem, 28 

all pairs, 591 

all-reduce, 452 
ALS, 741 

alternating least squares, 741 
alternative hypothesis, 179, 197 
ambient dimensionality, 688 
amortized inference, 683 
anchor, 553 

anchor boxes, 488 

ANN, 434 

Anscombe’s quartet, 43, 43 
approximate posterior inference, 151 
approximation error, 193 

ARD, 409, 568, 596 

ARD kernel, 568 

area under the curve, 173 
Armijo backtracking method, 284 
Armijo-Goldstein, 284 

artificial general intelligence, 28 
artificial intelligence, 28 
artificial neural networks, 434 
associative, 240 

asymptotic normality, 155 
asymptotically optimal, 160 
asynchronous training, 452 
atomic bomb, 73 

attention, 516, 520 

attention kernel, 652 

attention score, 517 

attention weight, 517 

AUC, 173 

augmented intelligence, 29 
auto-covariance matrix, 79 
AutoAugment, 625 
autocorrelation matrix, 79 
autodiff, 436 

autoencoder, 677 

automatic differentiation, 436 
automatic relevancy determination, 409, 568, 596 
AutoML, 483 

AutoRec, 743 

autoregressive model, 101 
average link clustering, 718 
average pooling, 473 

average precision, 175 

axis aligned, 82 

axis parallel splits, 601 

axon, 434 


B-splines, 397 
backbone, 431 

backfitting, 400 
backpropagation, 426 
backpropagation algorithm, 436 
backpropagation through time, 508 
backslash operator, 265 
backsubstitution, 265, 373 

bag, 653 

bag of word embeddings, 26 
bag of words, 24, 430 

bagging, 607 

BALD, 649 

balloon kernel density estimator, 561 
band-diagonal matrix, 237 
bandwidth, 457, 558, 566 
Barnes-Hut algorithm, 702 
barycentric coordinates, 696 
base measure, 93 

basis, 231 

basis function expansion, 423 
basis vectors, 241 

batch learning, 119 

batch normalization, 475, 476 
batch renormalization, 476 
BatchBALD, 649 

Bayes decision rule, 168 

Bayes error, 127, 545 

Bayes estimator, 168, 190 
Bayes factor, 179 

Bayes model averaging, 129 
Bayes risk, 190 

Bayes rule, 45 

Bayes rule for Gaussians, 87 
Bayes’ rule, 45, 45 

Bayes’s rule, 45 

Bayesian, 33 

Bayesian $\chi^2$-test, 187 

Bayesian active learning by disagreement, 649 
Bayesian decision theory, 167 
Bayesian deep learning, 455 
Bayesian factor regression, 675 
Bayesian inference, 44, 46 
Bayesian information criterion, 185 
Bayesian machine learning, 147 
Bayesian model selection, 185 
Bayesian network, 100 
Bayesian neural network, 455 
Bayesian Occam’s razor, 183 
Bayesian optimization, 648 
Bayesian personalized ranking, 745 
Bayesian statistics, 129, 154 
Bayesian t-test, 187 

BBO, 317 

Beam search, 513 

belief state, 46 

Berkson’s paradox, 102 
Bernoulli distribution, 49 
Bernoulli mixture model, 98 
BERT, 536 

Bessel function, 568 

beta distribution, 63, 121, 131 
beta function, 63 
beta-binomial, 135 

BFGS, 288 

bi-tempered logistic regression, 360 
bias, 8, 159, 369 

bias-variance tradeoff, 161 


BIC, 185 

BIC loss, 185 

BIC score, 185, 726 
biclustering, 735 
bidirectional RNN, 504 

big data, 3 

bigram model, 102, 210 
bijector, 66 

bilevel optimization, 194 
binary classification, 2, 46 
binary connect, 309 

binary cross entropy, 340 
binary entropy function, 206 
binary logistic regression, 337 
binomial coefficient, 50 
binomial distribution, 50, 50 
binomial regression, 414 
BinomialBoost, 616 

BIO, 539 

BiT, 530 

bits, 205 

bivariate Gaussian, 81 

black swan paradox, 122 
blackbox, 317 

blackbox optimization, 317 
block diagonal, 237 

block structured matrices, 248 
Blue Brain Project, 436 
BMA, 129 

BMM, 98 

BN, 475 

BNN, 455 

Boltzmann distribution, 55 
BookCrossing, 740 

Boolean logic, 34 

Boosting, 609 

bootstrap, 156 

bottleneck, 677 

bound optimization, 310, 352 
bounding boxes, 487 

bowl shape, 342 

box constraints, 306 

box plots, 44 

boxcar kernel, 558, 560 
branching factor, 210 

Brier score, 179, 643 
Brownian motion, 569 
byte-pair encoding, 26 


C-way N-shot classification, 651 
C4, 541 

C4.5, 603 

calculus, 267 

calculus of variations, 95 
calibration plot, 418 

canonical correlation analysis, 677 
canonical form, 94 

canonical link function, 415 
canonical parameters, 93 
CART, 601, 603 

Cartesian, 68 

Caser, 748 

CatBoost, 617 

categorical, 53 

categorical PCA, 675 

CatPCA, 675 

Cauchy, 62 
causal, 80 

causal CNN, 515 

causal convolution, 516 
CBOW, 705, 705, 706 

CCA, 677 

cdf, 37, 57 

center, 406 

centering matrix, 113, 245, 694 
central interval, 146 

central limit theorem, 60, 72 
centroids, 457 

chain rule for entropy, 209 
chain rule for mutual information, 217 
chain rule of calculus, 271 
chain rule of probability, 39 
change of variables, 67 
channels, 465, 471 
characteristic equation, 252 
characteristic length scale, 568 
characteristic matrix, 220 
chatbots, 541 

ChatGPT, 540, 541 
Chi-squared distribution, 65 
Cholesky decomposition, 380 
Cholesky factorization, 264 
CIFAR, 20 

city block distance, 716 

class conditional density, 321 
class confusion matrix, 172 
class imbalance, 173, 357, 591 
class-balanced sampling, 357 
classes, 2 

classical MDS, 690 

classical statistics, 154 
classification, 2, 776 
Classification and regression trees, 601 
CLIP, 633 

closed world assumption, 548 
cloze, 537 

cloze task, 630 

cluster assumption, 638 
Clustering, 713 

clustering, 97 

clusters, 14 

CNN, 3, 424, 465 
co-adaptation, 453 
Co-training, 640 

coclustering, 735 

code generation, 541 
codebook, 722 

coefficient of determination, 379 
CoLA, 541 

cold start, 748 

collaborative filtering, 735, 740 
column rank, 235 

column space, 231 

column vector, 227 
column-major order, 229 
committee method, 606 
commutative, 240 
compactness, 718 

comparison of classifiers, 186 
complementary log-log, 416 
complementary slackness, 303 
complete link clustering, 718 
completing the square, 88 
complexity penalty, 121 


composite objective, 280 
compositional, 432 

compound hypothesis, 198 
computation graph, 442 

computer graphics, 491 

concave, 276 

condition number, 123, 236, 284 
conditional computation, 460 
conditional distribution, 38 
conditional entropy, 208 

conditional instance normalization, 498 
conditional mixture model, 459 
conditional mutual information, 217 
conditional probability, 35 


conditional probability distribution, 7, 50, 100 


conditional probability table, 100 
conditional variance formula, 42 
conditionally independent, 35, 39, 99 
confidence interval, 146, 157 
confirmation bias, 637 

conformer, 531 

conjugate function, 277 

conjugate gradient, 284, 373 
conjugate prior, 87, 130, 130, 131 
consensus sequence, 207 

conservation of probability mass, 183 
Consistency regularization, 642 
consistent estimator, 191, 404 
constrained optimization, 275 
constrained optimization problem, 299 
constrained optimization problems, 401 
constraints, 275 


contextual word embeddings, 26, 535, 710 


contingency table, 188 
continual learning, 549 
continuation method, 396 
continuous optimization, 273 
continuous random variable, 36 
contraction, 681 

contractive autoencoder, 680 
contradicts, 521 

contrastive loss, 553 
contrastive tasks, 631 

control variate, 295 

convex, 342 

convex combination, 133 
convex function, 276 

convex optimization, 275 
convex relaxation, 384 

convex set, 275 

ConvNeXt, 483, 531 
convolution, 70, 466 
convolution theorem, 70 
convolution with holes, 484 
convolutional Markov model, 515 
convolutional neural network, 3, 21 


convolutional neural networks, 12, 424, 465 


coordinate descent, 395 

coordinate vectors, 231 

coordinated based representations, 427 
coreference resolution, 525 

coreset, 557 

corpus, 703 

correlation coefficient, 78, 82 
correlation does not imply causation, 79 
correlation matrix, 78, 114 

cosine kernel, 569 
cosine similarity, 704 

cost function, 273 

covariance, 77 

covariance matrix, 77, 81 
covariates, 2, 369, 776 
COVID-19, 46 

CPD, 100 

CPT, 100 

Cramer-Rao lower bound, 160 
credible interval, 143, 146, 146 
critical point, 301 

cross correlation, 467 

cross entropy, 207, 212, 214 
cross validation, 126, 195 
cross-covariance, 77 
cross-entropy, 178 

cross-over rate, 173 
cross-validated risk, 126, 195 
crosscat, 737 

crowding problem, 701 

cubic splines, 397 

cumulants, 95 

cumulative distribution function, 37, 57 
curse of dimensionality, 546 
curve fitting, 14 

curved exponential family, 94 
CV, 126, 195 

cyclic permutation property, 234 
cyclical learning rate, 294 


DAG, 100, 442 

data augmentation, 214, 625 
data compression, 16, 723 

data fragmentation, 603 

Data mining, 27 

data parallelism, 452 

data processing inequality, 221 
Data science, 27 

data uncertainty, 7, 34 
Datasaurus Dozen, 43, 44 

dead ReLU, 448 

debiasing, 387 

decision boundary, 5, 53, 149, 338 
decision making under uncertainty, 1 
decision rule, 5 

decision surface, 6 

decision tree, 6 

decision trees, 601 

decode, 655 

decoder, 677, 685 
deconvolution, 485 

deduction, 200 

deep CCA, 677 

deep factorization machines, 746 
deep graph infomax, 765 

deep metric learning, 550, 552 
deep mixture of experts, 461 
deep neural networks, 12, 423 
DeepDream, 495 

DeepWalk, 758 

default prior, 145 

defender’s fallacy, 75 

deflated matrix, 711 

deflation, 257 

degree of normality, 61 

degrees of freedom, 13, 61, 381 
delta rule, 292 


demonstrations, 18 

dendrites, 434 

dendrogram, 716 

denoising autoencoder, 679 
dense prediction, 490 

dense sequence labeling, 505 
DenseNets, 482 

density estimation, 16 

density kernel, 518, 558 
dependent variable, 369 

depth prediction, 490 

depthwise separable convolution, 486 
derivative, 267 

derivative free optimization, 317 
descent direction, 281, 282 
design matrix, 3, 243, 424, 776 
determinant, 235 

development set, 124 

deviance, 418, 604 

DFO, 317 

diagonal covariance matrix, 83 
diagonal matrix, 237 
diagonalizable, 253 

diagonally dominant, 238 
diameter, 718 

differentiable programming, 442 
differential entropy, 210 
differentiating under the integral sign, 70 
differentiation, 268 

diffuse prior, 145 

dilated convolution, 484, 489 
dilation factor, 484 
dimensionality reduction, 4, 655 
Dirac delta function, 60, 148 
directed acyclic graph, 100, 442 
directional derivative, 268 
Dirichlet distribution, 138, 332 
Dirichlet energy, 699 

discrete AdaBoost, 612 

discrete optimization, 273 
discrete random variable, 35 
discretize, 211, 218 
discriminant function, 322 
discriminative classifier, 321, 334 
dispersion parameter, 413 
distance metric, 211 

distant supervision, 653 
distortion, 655, 720, 722 
distributional hypothesis, 703 
distributive, 240 

divergence measure, 211 

diverse beam search, 514 

DNA sequence motifs, 206 
DNN, 12, 423 

document retrieval, 704 
document summarization, 22 
domain adaptation, 628, 635 
domain adversarial learning, 635 
dominates, 191, 198 

dot product, 240 

double centering trick, 245 
double sided exponential, 63 
dropout, 453 

dual feasibility, 303 

dual form, 586 

dual problem, 585 

dual variables, 592 
dummy encoding, 23, 53 
Dutch book theorem, 202 
dynamic graph, 444 
dynamic programming, 513 


E step, 310, 312 

early stopping, 126, 453 
EB, 145 

echo state network, 509 
ECM, 668 

economy sized QR, 263 
economy sized SVD, 258 
edge devices, 309, 486 
EER, 173 

effect size, 186 
EfficientNetv2, 483 
eigenfaces, 657 

eigenvalue, 251 

eigenvalue decomposition, 251 
eigenvalue spectrum, 123 
eigenvector, 251 

Einstein summation, 246 
einsum, 246 

elastic embedding, 700 
elastic net, 388, 394 

ELBO, 153, 312, 683 
elbow, 726 

electronic health records, 521 
ell-2 loss, 9 

ELM, 582 

ELMo, 536 

ELU, 448 

EM, 85, 310 

EMA, 119 

email spam classification, 22 
embarrassingly parallel, 452 
embedding, 655 

EMNIST, 20 

empirical Bayes, 145, 182, 409 


empirical distribution, 66, 72, 109, 192, 213 


empirical risk, 6, 192 


empirical risk minimization, 7, 115, 193, 291 


encode, 655 
encoder, 677, 685 
encoder-decoder, 489 


encoder-decoder architecture, 507 


endogenous variables, 2 
energy based model, 633 
energy function, 152 
ensemble, 294, 455 
ensemble learning, 606 
entails, 521 

entity discovery, 720 

entity linking, 549 

entity resolution, 540, 549 
entropy, 178, 205, 312, 604 
entropy minimization, 637 
entropy SGD, 456 
Epanechnikov kernel, 559 
epigraph, 276 

epistemic uncertainty, 7, 34 
epistemology, 34 

epoch, 291 


epsilon insensitive loss function, 593 


equal error rate, 173 
equality constraints, 275, 300 
equitability, 220 


equivalent sample size, 131 
equivariance, 473 

ERM, 115, 193 

error function, 57 
estimation error, 193 
estimator, 154 

EVD, 251 

event, 34, 34, 35, 35 
events, 33 

evidence, 135, 181 


evidence lower bound, 153, 312, 683 


EWMA, 119, 297 

exact line search, 284 
exchangeable, 102 

exclusive KL, 214 

exclusive or, 458 
exemplar-based models, 545 
exemplars, 457 

exogenous variables, 2 
expectation maximization, 310 


expected complete data log likelihood, 313 
expected sufficient statistics, 313 


expected value, 40, 58 
experiment design, 648 
explaining away, 102 
explanatory variables, 369 
explicit feedback, 739 

exploding gradient problem, 445 


exploration-exploitation tradeoff, 749 


exploratory data analysis, 4 


exponential dispersion family, 413 


Exponential distribution, 64 
exponential family, 93, 96, 144 


exponential family factor analysis, 673 


exponential family PCA, 673 
Exponential linear unit, 446 
exponential loss, 612 
exponential moving average, 119 


exponentially weighted moving average, 119 


exponentiated cross entropy, 210 


exponentiated quadratic kernel, 566 


extractive summarization, 540 
extreme learning machine, 582 


F score, 175 

face detection, 487 

face recognition, 487 

face verification, 549 
FaceNet, 555 

factor analysis, 16, 664 
factor loading matrix, 665 
factorization machine, 746 
FAISS, 548 

false alarm rate, 172 

false negative rate, 46 
false positive rate, 46, 172 
fan-in, 451 

fan-out, 451 

Fano’s inequality, 223 
farthest point clustering, 723 
Fashion-MNIST, 20 

fast adaptation, 650 

fast Hadamard transform, 582 
fastfood, 582 

feasibility, 302 

feasibility problem, 275 
feasible set, 275 
feature crosses, 24 

feature detection, 469 

feature engineering, 11 

feature extraction, 11 

feature extractor, 370 

feature importance, 619, 619 
feature map, 469 

feature preprocessing, 11 

feature selection, 223, 308, 383 
features, 1 

featurization, 4 

feedforward neural network, 423 
few-shot classification, 549 
few-shot learning, 651 

FFNN, 423 

fill in, 27 

fill-in-the-blank, 537, 630 

filter, 467 

filter response normalization, 477 
filters, 465 

FIM, 155 

fine-grained classification, 21, 651 
fine-grained visual classification, 627 
fine-tune, 535 

fine-tuning phase, 627 

finite difference, 268 

finite sum problem, 291 

first order, 344 

first order Markov condition, 101 
first-order, 280, 287 

Fisher information matrix, 155, 346 
Fisher scoring, 346 

Fisher’s linear discriminant analysis, 326 
FISTA, 396 

FLAN-T5, 541 

flat local minimum, 274 

flat minima, 455 

flat prior, 144 

flatten, 429 

FLDA, 326 

folds, 126, 195 

forget gate, 511 

forward mode differentiation, 437 
forward stagewise additive modeling, 610 
forwards KL, 214 

forwards model, 49 

founder variables, 671 

fraction of variance explained, 663 
fraud detection system, 768 
frequentist, 33 

frequentist decision theory, 188 
frequentist statistics, 154 
Frobenius norm, 234 

frozen parameters, 628 

full covariance matrix, 82 

full rank, 235 

full-matrix Adagrad, 299 
function space, 575 

furthest neighbor clustering, 718 
fused batchnorm, 475 


gallery, 488, 548 

GAM, 399 

gamma distribution, 64 

GANs, 635 

Gated Graph Sequence Neural Networks, 760 
gated recurrent units, 510 
gating function, 460 

Gaussian, 9 

Gaussian discriminant analysis, 321 
Gaussian distribution, 57 

Gaussian kernel, 457, 533, 558, 566 
Gaussian mixture model, 97 
Gaussian process, 458 

Gaussian process regression, 399 
Gaussian processes, 572 

Gaussian scale mixture, 105 

GCN, 761 

GDA, 321 

GELU, 446, 449 

generalization error, 193, 195 
generalization gap, 13, 193 
generalize, 7, 121 

generalized additive model, 399 
generalized CCA, 677 

generalized eigenvalue, 328 
generalized Lagrangian, 302, 585 
generalized linear models, 413 
generalized low rank models, 673 
generalized probit approximation, 365 
Generative adversarial networks, 645 
generative classifier, 321, 334 
generative image model, 491 
Geometric Deep Learning, 751 
geometric series, 120, 285 

Gini index, 603 

glmnet, 395 

GLMs, 413 

global average pooling, 430, 474 
global optimization, 273 

global optimum, 273, 342 

globally convergent, 274 

Glorot initialization, 451 

GloVe, 707 

GMM, 97 

GMRES, 374 

GNN, 424, 760 

goodness of fit, 378 

GoogLeNet, 480 

GPT, 540 

GPT-2, 540 

GPT-3, 540 

GPUs, 433, 452 

GPyTorch, 581 

gradient, 268, 281 

gradient boosted regression trees, 616 
gradient boosting, 614 

gradient clipping, 445 

gradient sign reversal, 635 

gradient tree boosting, 616 

Gram matrix, 239, 245, 498, 566 
Gram Schmidt, 240 

Graph attention network, 762 
Graph convolutional networks, 761 
graph factorization, 757 

graph Laplacian, 697, 733 

graph neural network, 760 

graph neural networks, 424 

graph partition, 732 

Graph Representation Learning, 751 
graphical models, 40 

Graphical Mutual Information, 766 
graphics processing units, 452 
graphite, 765 
GraphNet, 761
GraphSAGE, 761
GraRep, 757
greedy decoding, 513
greedy forward selection, 397
grid approximation, 152
grid search, 125, 318
group lasso, 392
group normalization, 477
group sparsity, 391
grouping effect, 394
GRU, 510
Gshard, 531
Gumbel noise, 514


HAC, 715
half Cauchy, 63
half spaces, 338
Hamiltonian Monte Carlo, 154
hard attention, 523
hard clustering, 98, 728
hard negatives, 554
hard thresholding, 386, 389
hardware accelerators, 435
harmonic mean, 175
hat matrix, 373
HDI, 147
He initialization, 451
heads, 431
heat map, 469
Heaviside, 344
Heaviside step function, 441
heaviside step function, 52, 424
heavy ball, 285
heavy tails, 62, 400
Helmholtz machine, 683
Hessian, 342, 774
Hessian matrix, 270
heteroskedastic regression, 59, 374
heuristics, 436
hidden, 45
hidden common cause, 79
hidden units, 425
hidden variables, 102, 310, 776
hierarchical, 432
hierarchical agglomerative clustering, 715
hierarchical Bayesian model, 145
hierarchical mixture of experts, 462
hierarchical softmax, 356
hierarchy, 355
highest density interval, 147
highest posterior density, 147
hinge loss, 116, 318, 589, 745
Hinton diagram, 86
hit rate, 172
HMC, 154
Hoeffding’s inequality, 196
hogwild training, 452
holdout set, 194
homogeneous, 101
homoscedastic regression, 59
homotopy, 396
HPD, 147
Huber loss, 177, 402, 615
Huffman encoding, 356
human pose estimation, 491
Hutchinson trace estimator, 234, 235
hyper-parameters, 131, 317
hypercolumn, 472
hypernyms, 356
hyperparameter, 194
hyperparameters, 145
hyperplane, 338
hypothesis, 521
hypothesis space, 193
hypothesis testing, 179


I-projection, 214
IA, 29
ID3, 603
identifiability, 353
identifiable, 191, 353
identity matrix, 237
iid, 71, 108, 130
ill-conditioned, 114, 236
ill-posed, 49
ILP, 305
ILSVRC, 21
image captioning, 503
image classification, 3
image compression, 723
image interpolation, 686
image patches, 465
image tagging, 348, 486
image-to-image, 490
ImageNet, 21, 433, 479
ImageNet-21k, 531
IMDB, 128
IMDB movie review dataset, 22
implicit feedback, 745
implicit regularization, 455
impostors, 550
imputation tasks, 630
inception block, 480
Inceptionism, 495
inclusive KL, 214
incremental learning, 549
indefinite, 238
independent, 35, 39
independent and identically distributed, 71, 130
independent variables, 369
indicator function, 6, 36
induced norm, 233
inducing points, 581
induction, 122, 200
inductive bias, 13, 425, 530
inductive learning, 641
inequality constraints, 275, 300
infeasible, 304
inference, 107, 129
inference network, 682
infinitely wide, 580
InfoNCE, 554
information, 33
information content, 205
information criteria, 185
information diagram, 216
information diagrams, 217
information extraction, 539
information gain, 211, 648
information gathering action, 7
information projection, 214
information retrieval, 173
information theory, 178, 205
inner product, 240

input gate, 511 

Instagram, 487 

instance normalization, 476 
instance segmentation, 488 
instance-balanced sampling, 357 
instance-based learning, 545 
InstructGPT, 540 

instruction fine-tuning, 541 
Integer linear programming, 305 
integrated risk, 190 

integrating out, 129 

intelligence augmentation, 29 
inter-quartile range, 44 
interaction effects, 23 

intercept, 8 

interior point method, 304 
internal covariate shift, 475 
interpolate, 11 

interpolated precision, 175 
interpolator, 572 

interpretable, 17 

intrinsic dimensionality, 688 
inverse, 247 

inverse cdf, 38 

inverse document frequency, 25 
inverse Gamma distribution, 65 
inverse probability, 49 

inverse problems, 49 

inverse reinforcement learning, 28 
inverse Wishart, 122 

Iris, 2

Iris dataset, 3 

IRLS, 346 

isomap, 692 

isotropic covariance matrix, 83 
ISTA, 396 

items, 739 

iterate averaging, 295 

iterative soft thresholding algorithm, 396 
iteratively reweighted least squares, 346 


Jacobian, 68, 350 

Jacobian formulation, 269 
Jacobian matrix, 269 
Jacobian vector product, 269 
Jensen’s inequality, 212, 312 
Jeopardy, 170 

Jester, 740 

JFT, 531 

jittered, 458 

joint distribution, 38 

joint probability, 34 

JPEG, 723 

just in time, 444 

JVP, 269 


K nearest neighbor, 545 
k-d tree, 548 

K-means algorithm, 720 
K-means clustering, 98 
K-means++, 723 
K-medoids, 723 

Kalman filter, 91 

Karl Popper, 122 
Karush-Kuhn-Tucker, 303 
Katz centrality index, 756

KDE, 558, 560 

kernel, 467, 558 

kernel density estimation, 558 
kernel density estimator, 560 
kernel function, 457, 565
kernel PCA, 693, 735 

kernel regression, 518, 562, 574 
kernel ridge regression, 574, 593 
kernel smoothing, 562 

kernel trick, 588 

keys, 516 

keywords, 259 

kink, 726 

KKT, 303 

KL divergence, 109, 178, 211, 312 
KNN, 545 

knots, 397 

Knowledge distillation, 647 
knowledge graph, 768 

Kronecker product, 246 

Krylov subspace methods, 581 
KSG estimator, 218 

Kullback Leibler divergence, 109, 178 
Kullback-Leibler divergence, 211, 312 


L-BFGS, 289 

L0-norm, 383, 384

L1 loss, 177 

L1 regularization, 383 

L1VM, 596 

L2 loss, 176 

L2 regularization, 123, 347, 379 
L2VM, 595 

label, 2 

label noise, 358, 653 

Label propagation, 641, 759 
label smearing, 356 

label smoothing, 648, 653 

Label spreading, 759 

label switching problem, 317, 730 
Lagrange multiplier, 301 
Lagrange multipliers, 95, 111 
Lagrange notation, 268 
Lagrangian, 95, 257, 275, 301, 384 
Lanczos algorithm, 669 

language model, 101 

language modeling, 23, 503 
language models, 209, 535 
Laplace, 383 

Laplace approximation, 152, 361 
Laplace distribution, 63 

Laplace smoothing, 333 

Laplace vector machine, 596 
Laplace’s rule of succession, 134 
Laplacian eigenmaps, 696, 734, 755 
LAR, 397 

Large language models, 541 
large margin classifier, 583 

large margin nearest neighbor, 550 
LARS, 396 

lasso, 305, 383 

latent coincidence analysis, 551 
latent factors, 16, 657 

latent semantic analysis, 705 
latent semantic indexing, 704 
latent space interpolation, 686 
latent variable, 96

latent variable models, 776 
latent variables, 776 

latent vector, 657 

law of iterated expectations, 41 
law of total expectation, 41 

law of total variance, 42 

layer normalization, 476 
layer-sequential unit-variance, 451 
LCA, 551 

LDA, 321, 323 

Leaky ReLU, 446 

leaky ReLU, 448 

learning curve, 127 

learning rate, 281 

learning rate schedule, 282, 292, 293 
learning rate warmup, 294 
learning to learn, 650 

learning with a critic, 18 
learning with a teacher, 18 

least angle regression, 397 

least favorable prior, 191 

least mean squares, 292, 376 
least squares boosting, 397, 610 
least squares objective, 266 
least squares solution, 10 
leave-one-out cross-validation, 126, 195 
LeCun initialization, 451 

left pseudo inverse, 267 

Leibniz notation, 268 

LeNet, 474, 477 

level sets, 82, 83 

life-long learning, 549 
LightGBM, 617 

likelihood, 45 

likelihood function, 129 
likelihood principle, 201 
likelihood ratio, 179, 200 
likelihood ratio test, 197 
limited memory BFGS, 289, 352 
line search, 283 

Linear algebra, 227 

linear autoencoder, 678 

linear combination, 241 

linear discriminant analysis, 321, 323 
linear function, 8 

linear Gaussian system, 86 
linear kernel, 579 

linear map, 231 

linear operator, 374 

linear programming, 401 

linear rate, 284 

linear regression, 59, 369, 413, 423 
linear subspace, 241 

linear threshold function, 424 
linear transformation, 231 
linearity of expectation, 40 
linearly dependent, 230 

linearly independent, 230 
linearly separable, 338 
Linformer, 533 

link function, 413, 415 

link prediction, 767 

Lipschitz constant, 279 

liquid state machine, 509 
LLMs, 541 

LMNN, 550 


LMS, 292 

local linear embedding, 695 
local maximum, 274 

local minimum, 273 

local optimum, 273, 342 
locality sensitive hashing, 548 
locally linear regression, 563 


locally-weighted scatterplot smoothing, 563 


LOESS, 563 

log bilinear language model, 709 
log likelihood, 108 

log loss, 179, 590 

log odds, 52 

log partition function, 93 
log-sum-exp trick, 56 
logistic, 51 

logistic function, 52, 148 
Logistic regression, 337 
logistic regression, 8, 53, 148, 413 
logit, 51, 337, 340 

logit adjustment, 357 

logit function, 52 

LogitBoost, 614

logits, 7, 55, 348 

long short term memory, 511 
long tail, 151, 356 

Lorentz, 62 

Lorentz model, 756 

loss function, 6, 9, 167, 273 
lossy compression, 722 

lower triangular matrix, 238 
LOWESS, 563 

LSA, 705 

lse, 56 

LSH, 548 

LSI, 704 

LSTM, 511 


M step, 310 

M’th order Markov model, 101 
M-projection, 214 

M1, 644 

M2, 644 

machine learning, 1 

machine translation, 22, 507 
Mahalanobis distance, 83, 545 
Mahalanobis whitening, 256 
main effects, 23 
majorize-minimize, 310 
MALA, 491 

MAML, 650 

manifold, 687

manifold assumption, 641 
manifold hypothesis, 687 
manifold learning, 687 

mAP, 175 

MAP estimate, 169 

MAP estimation, 121, 194 
MAR, 27 

margin, 116, 583, 612 
margin errors, 588 

marginal distribution, 38 


marginal likelihood, 45, 129, 135, 137, 145, 181 


marginalizing out, 129, 148 
marginalizing over, 129 
marginally independent, 39 
Markov chain, 101 
Markov chain Monte Carlo, 153
Markov kernel, 101 

Markov model, 101 

MART, 616 

masked attention, 518 

masked language model, 537 
matched filter, 466 

matching network, 652 

Matern kernel, 568 

matrix, 227 

matrix completion, 741 

matrix determinant lemma, 250 
matrix factorization, 741 

matrix inversion lemma, 249, 592 
matrix square root, 233, 243, 264 
matrix vector multiplication, 581 
max pooling, 473 

maxent classifier, 355 

maximal information coefficient, 219 
maximum a posteriori estimation, 121
maximum a posteriori, 169 
maximum entropy, 60, 205 
maximum entropy classifier, 355
maximum entropy model, 95 
maximum entropy sampling, 649 
maximum expected utility principle, 168 
maximum likelihood estimate, 8 
maximum likelihood estimation, 107 
maximum risk, 190 

maximum variance unfolding, 695 
MCAR, 27 

McCulloch-Pitts model, 434 
McKernel, 582 

MCMC, 153 

MDL, 186 

MDN, 461 

MDS, 689 

mean, 40, 58 

mean average precision, 175 

mean function, 413 

mean squared error, 9, 115 

mean value imputation, 27 
median, 38, 57 

median absolute deviation, 561 
medoid, 723 

memory cell, 511 

memory-based learning, 545 
Mercer kernel, 565 

Mercer’s theorem, 566 

message passing neural networks, 760 
meta-learning, 650, 652 

method of moments, 117 

metric MDS, 691 
Metropolis-adjusted Langevin algorithm, 491 
MICe, 220 

min-max scaling, 348 

minibatch, 291 

minimal, 93 

minimal representation, 94 
minimal sufficient statistic, 222 
minimally informative prior, 145 
minimax estimator, 190 

minimum description length, 186 
minimum mean squared error, 176 
minimum spanning tree, 717 
minorize-maximize, 310 

MIP, 305 
misclassification rate, 6, 116
missing at random, 27, 739 
missing completely at random, 27 
missing data, 27, 310 

missing data mechanism, 27, 645 
missing value imputation, 85 
mixed ILP, 305 

mixing weights, 137 

mixmatch, 638 

mixture density network, 461 
mixture model, 96 

mixture of Bernoullis, 98 

mixture of beta distributions, 136 
mixture of experts, 460, 531 
mixture of factor analysers, 672 
mixture of Gaussians, 97 

ML, 1 

MLE, 8, 107 

MLP, 423, 425 

MLP-mixer, 425 

MM, 310 

MMSE, 176 

MNIST, 19, 477 

MobileNet, 486 

MoCo, 633 

mode, 41, 169 

mode-covering, 214 
mode-seeking, 215 

model compression, 453 

model fitting, 7, 107 

model parallelism, 452 

model selection, 180 

model selection consistent, 390 
model uncertainty, 7, 34 
model-agnostic meta-learning, 650 
modus tollens, 200 

MoE, 460 

MoG, 97 

moment projection, 214 
momentum, 285 

momentum contrastive learning, 633 
MoNet, 763 

Monte Carlo approximation, 72, 153, 364 
Monte Carlo dropout, 455 

Monty Hall problem, 47 
Moore-Penrose pseudo-inverse, 259 
most powerful test, 198 

motes, 11 

MovieLens, 740 

moving average, 119 

MSE, 9, 115 

multi-class classification, 348 
multi-clust, 737 
Multi-dimensional scaling, 755 
multi-headed attention, 525 
multi-instance learning, 653 
multi-label classification, 348 
multi-label classifier, 356 
multi-level model, 145 
multi-object tracking, 549 
multiclass logistic regression, 337 
multidimensional scaling, 689 
multilayer perceptron, 423, 425 
multimodal, 41 

multinomial coefficient, 54 
multinomial distribution, 53, 54 
Multinomial logistic regression, 348 
multinomial logistic regression, 55, 337
multinomial logit, 54 

multiple imputation, 85 

multiple linear regression, 10, 369 
multiple restarts, 723 

multiplicative interaction, 516 
multivariate Bernoulli naive Bayes, 330 
multivariate Gaussian, 80 

multivariate linear regression, 370 
multivariate normal, 80 

mutual information, 79, 215 

mutually independent, 74 

MVM, 581 

MVN, 80 

myopic, 649 


N-pairs loss, 554 
Nadaraya-Watson, 562 

naive Bayes assumption, 325, 330 
naive Bayes classifier, 330 

named entity recognition, 539 
NAS, 483 

nats, 205 

natural exponential family, 94 
natural language inference, 521 
natural language processing, 22, 355 
natural language understanding, 49 
natural parameters, 93 

NBC, 330 

NCA, 550 

NCM, 326 

nearest centroid classifier, 326 
nearest class mean classifier, 326, 357 
nearest class mean metric learning, 326 
nearest neighbor clustering, 717 
NEF, 94 

negative definite, 238 

negative log likelihood, 8, 108 
negative semidefinite, 238 
neighborhood components analysis, 550 
neocognitron, 474 

nested optimization, 194 

nested partitioning, 737 

Nesterov accelerated gradient, 286 
Netflix Prize, 739 

NetMF, 758 

neural architecture search, 483 
neural implicit representations, 427 
neural language model, 102 

neural machine translation, 507 
neural matrix factorization, 747 
neural style transfer, 495 

neural tangent kernel, 580 
NeurIPS, 18 

neutral, 521 

Newton’s method, 287, 345 

next sentence prediction, 538 
Neyman-Pearson lemma, 198 
NHST, 199 

NHWC, 472 

NIPS, 18 

NLL, 108 

NLP, 22 

NMAR, 27 

no free lunch theorem, 13 
node2vec, 758 

noise floor, 127 


non-identifiability, 730 
non-identifiable, 408 
non-metric MDS, 691 
non-parametric bootstrap, 157 
non-parametric methods, 687 
non-parametric model, 458 


non-saturating activation functions, 427, 447 


noninformative, 144 

nonlinear dimensionality reduction, 687 
nonlinear factor analysis, 672 
nonparametric methods, 565 
nonparametric models, 545 

nonsmooth optimization, 279 

norm, 232, 236 

normal, 9 

normal distribution, 57 

normal equations, 267, 371 


Normal-Inverse-Wishart distribution, 316 


normalization layers, 474 
normalized, 239 
normalized cut, 733 


normalized mutual information, 219, 715 


normalizer-free networks, 477 
Normalizing flows, 646 

not missing at random, 27 
noun phrase chunking, 539 
novelty detection, 549 
NT-Xent, 554 

nu-SVM classifier, 588 

nuclear norm, 233 

nucleotide, 206 

null hypothesis, 179, 186, 197 
null hypothesis significance testing, 199 
nullspace, 231 

numerator layout, 269 


object detection, 487 
objective, 144 

objective function, 108, 273 
observation distribution, 45 
Occam factor, 185 

Occam’s razor, 182 

offset, 10, 369 

Old Faithful, 314 

Olivetti face dataset, 656 
OLS, 115, 267, 371 
one-cycle learning rate schedule, 294 
one-hot, 178 

one-hot encoding, 23, 350, 776 
one-hot vector, 53, 227 
one-shot learning, 326, 651 
one-sided p-value, 199 
one-sided test, 186 
one-standard error rule, 126 
one-to-many functions, 459 
one-versus-one, 591 
one-versus-the-rest, 591 
one-vs-all, 591 

online learning, 119, 309, 549 
OOD, 549 

OOV, 24, 26 

open class, 26 

open set recognition, 548 
open world, 487 

open world assumption, 549 
OpenPose, 491 

opt-einsum, 247 
optimal policy, 168

optimism of the training error, 194 
optimization problem, 273 

order, 229 

order statistics, 118 

ordered Markov property, 100 
ordering constraint, 731 


ordinary least squares, 115, 267, 371 


Ornstein-Uhlenbeck process, 569 
orthodox statistics, 154 
orthogonal, 239, 253 

orthogonal projection, 373 
orthogonal random features, 582 
orthonormal, 239, 253 

out of vocabulary, 26 

out-of-bag instances, 607 
out-of-distribution, 549 
out-of-sample generalization, 687 
out-of-vocabulary, 24 

outer product, 241 

outliers, 61, 177, 358, 400 
output gate, 511 

over-complete representation, 94 
over-parameterized, 55 
overcomplete representation, 677 
overdetermined, 264 
overdetermined system, 372 
overfitting, 13, 120, 133 


p-value, 180, 199 

PAC learnable, 195 

PageRank, 256 

pair plot, 4 

paired test, 186 

pairwise independent, 73 

PAM, 724 

panoptic segmentation, 489 
parameter space, 273 

parameter tying, 101 
parameters, 6 

parametric bootstrap, 157 
parametric models, 545 
parametric ReLU, 448 

part of speech tagging, 539 
part-of-speech, 536 

partial dependency plot, 621 
partial derivative, 268 

partial least squares, 676 

partial pivoting, 262 

partial regression coefficient, 375 
partially observed, 167 

partition function, 56, 93, 633 
partitioned inverse formulae, 248 
partitioning around medoids, 724 


Parzen window density estimator, 560 


pathologies, 201 

pattern recognition, 2 
PCA, 16, 655, 656 

PCA whitening, 255 

pdf, 37, 58 

peephole connections, 512 
penalty term, 307 
percent point function, 38 
perceptron, 344, 424 
perceptron learning algorithm, 344 
Performer, 533 

periodic kernel, 569 


permutation test, 199 
perplexity, 209, 503 

person re-identification, 549 
PersonLab, 491 

perturbation theory, 734 

PGM, 100 

Planetoid, 766 

plates, 103 

Platt scaling, 589 

PLS, 676 

plug-in approximation, 133, 148 
plugin approximation, 364 
PMF, 742 

pmf, 36 

PMI, 705 

Poincaré model, 756 

point estimate, 107 

point null hypothesis, 186 
pointwise convolution, 472 
pointwise mutual information, 705 
Poisson regression, 415 

polar, 68 

policy, 17 

Polyak-Ruppert averaging, 295 
polynomial expansion, 370 
polynomial regression, 10, 123, 370 
polytope, 303 

pool-based active learning, 648 
population risk, 13, 125, 192 
POS, 536 

position weight matrix, 206, 207 
positional embedding, 526 
positive definite, 238 

positive definite kernel, 565 
positive PMI, 705 

positive semidefinite, 238 
post-norm, 528 

posterior, 129 

posterior distribution, 46, 129 
posterior expected loss, 167 
posterior inference, 46 
posterior mean, 176 

posterior median, 177 


posterior predictive distribution, 129, 134, 148, 364, 404 


power, 198 

power method, 256 

PPCA, 666 

ppf, 38 

pre-activation, 337 
pre-activations, 426 
pre-norm, 528 

pre-train, 535 

pre-trained word embedding, 26 
pre-training phase, 627 
preactivation resnet, 482 
precision, 57, 141, 174
precision at K, 174 
precision matrix, 84, 112 
precision-recall curve, 174 
preconditioned SGD, 296 
preconditioner, 296 
preconditioning matrix, 296 
predictive analytics, 27 
predictors, 2 

preferences, 167 

premise, 521 

PreResnet, 482 
pretext tasks, 631
prevalence, 46, 175 
primal problem, 585 
primal variables, 593 


principal components analysis, 16, 655 


principal components regression, 382 
prior, 121, 129 

prior distribution, 45 

probabilistic forecasting, 177 
probabilistic graphical model, 100 
probabilistic inference, 46 
probabilistic matrix factorization, 742 
probabilistic PCA, 655 

probabilistic perspective, 1 
probabilistic prediction, 177 


probabilistic principal components analysis, 666 


probability density function, 37, 58 
probability distribution, 177 
probability distributions, 1 
probability mass function, 36 
probability simplex, 138 
probability theory, 45 

probably approximately correct, 195 
probit approximation, 365 
probit function, 57, 365 

probit link function, 416 
product rule, 39 

product rule of probability, 45 
profile likelihood, 663 

profile log likelihood, 664 
projected gradient descent, 307, 396 
projection, 232 

projection matrix, 373 
prompt, 540 

prompt engineering, 541, 635 
proper scoring rule, 179 
prosecutor’s fallacy, 75 
proxies, 555 

proximal gradient descent, 396 
proximal gradient method, 306 
proximal operator, 306 
ProxQuant, 309 

proxy tasks, 631 

prune, 604 

psd, 238 

pseudo counts, 131, 332 
pseudo inputs, 581 

pseudo inverse, 371 

pseudo norm, 233 
pseudo-labeling, 637 
pseudo-likelihood, 537 

pure, 604 

purity, 714 

Pythagoras’s theorem, 266 


QALY, 167 

QP, 304 

quadratic approximation, 152 
quadratic discriminant analysis, 322 
quadratic form, 238, 254 
quadratic kernel, 566 

quadratic loss, 9, 176 

quadratic program, 304, 384, 594 
quality-adjusted life years, 167 
quantile, 38, 57 

quantile function, 38 
quantization, 210 


quantize, 211, 218 
quantized, 309 

quartiles, 38, 57 
Quasi-Newton, 288 
quasi-Newton, 352 

queries, 516 

query synthesis, 648 
question answering, 22, 540 


radial basis function, 558 

radial basis function kernel, 457, 533 
Rand index, 714 

RAND-WALK, 709 

random finite sets, 549 

random forests, 608 

random Fourier features, 582 
random number generator, 72 
random shuffling, 291 

random variable, 35 

random variables, 1 

random walk kernel, 571 

range, 231 

rank, 229, 235 

rank deficient, 235 

rank one update, 249 

rank-nullity theorem, 260 

ranking loss, 553, 745 

RANSAC, 402 

rate, 484, 723 

rate of convergence, 284 

rating, 739 

Rayleigh quotient, 256 

RBF, 558 

RBF kernel, 457, 566 

RBF network, 457 

real AdaBoost, 612 

recall, 172, 174 

receiver operating characteristic, 173 
receptive field, 471, 484 
recognition network, 682 
recommendation systems, 768 
Recommender systems, 739 
reconstruction error, 655, 657, 722 
Rectified linear unit, 446 

rectified linear unit, 427, 447 
recurrent neural network, 501 
recurrent neural networks, 12, 424 
recursive update, 119 

recursively, 375 
reduce-on-plateau, 294 

reduced QR, 263 

Reformer, 533 

region of practical equivalence, 186 
regression, 8, 776 

regression coefficient, 375 
regression coefficients, 8, 369 
regularization, 121 

regularization parameter, 121 
regularization path, 382, 387, 396 
regularized discriminant analysis, 325 
regularized empirical risk, 193 
reinforcement learning, 17, 749 


reinforcement learning from human feedback, 540 


reject option, 170 

reject the null hypothesis, 199 
relational data, 739 

relative entropy, 211 
relevance vector machine, 596
ReLU, 427, 447 
reparameterization trick, 684 
representation learning, 631 
reservoir computing, 510 

reset gate, 511 

reshape, 229 

residual block, 449, 481 
residual error, 115 

residual network, 449 

residual plot, 378 

residual sum of squares, 115, 371 
residuals, 9, 378 

ResNet, 449, 481 

ResNet-18, 482 

response, 2 

response variables, 776 
responsibility, 98, 313, 461 
reverse KL, 214 

reverse mode differentiation, 438 
reward, 18 

reward function, 273 

reward hacking, 28 

RFF, 582 

ridge regression, 123, 162, 379, 453 
Riemannian manifold, 687 
Riemannian metric, 687 

right pseudo inverse, 266 

risk, 167, 188 

risk averse, 169, 170 

risk neutral, 169 

risk sensitive, 169 

RL, 17 

RLHF, 540 

RMSE, 115, 379 

RNN, 424, 501 
Robbins-Monro conditions, 293 
robust, 9, 61, 177 

robust linear regression, 304 
robust logistic regression, 358 
robustness, 400 

ROC, 173 

root mean squared error, 115, 379 
ROPE, 186 

rotation matrix, 239 

row rank, 235 

row-major order, 229 

RSS, 115 

rule of iterated expectation, 104 
rule of total probability, 38 
running sum, 119 

rv, 35 

RVM, 596 


saddle point, 275, 278 
SAGA, 296, 344 

same convolution, 469 
SAMME, 613 

Sammon mapping, 692 
sample efficiency, 17 
sample mean, 160 
sample size, 2, 110, 130 
sample space, 35 
sample variance, 143 
sampling distribution, 154 
SARS-CoV-2, 46 
saturated model, 418 
saturates, 427

scalar field, 268 

scalar product, 240 

scalars, 230 

scale of evidence, 180 

scaled dot-product attention, 519 
scatter matrix, 113, 244 
Schatten p-norm, 233 

scheduled sampling, 508 

Schur complement, 84, 248 

score function, 155, 273, 680 
scree plot, 662 

second order, 344 

Second-order, 287 

self attention, 524 
self-normalization property, 709 
self-supervised, 630 
self-supervised learning, 16 
self-training, 636, 648 

SELU, 448 

semantic role labeling, 355 
semantic segmentation, 489 
semantic textual similarity, 523 
semi-hard negatives, 555 
semi-supervised embeddings, 766 
Semi-supervised learning, 636 
semi-supervised learning, 335 
semidefinite embedding, 694 
semidefinite programming, 550, 695 
sensible PCA, 666 

sensitivity, 46, 172 

sensor fusion, 92 

sentiment analysis, 22 

seq2seq, 505 

seq2seq model, 22 

seq2vec, 504 

sequence logo, 207 

sequence motif, 207 

sequential minimal optimization, 586 
SGD, 290 

SGNS, 707 

shaded nodes, 102 

shallow parsing, 539 

Shampoo, 299 
structural risk minimization, 194
structured data, 423 

STS Benchmark, 523 

STSB, 541 

Student distribution, 61 